## 0) Testing CNN Backbone Implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

class densetFPN_121(nn.Module):
    """ DenseNet121-based Feature Pyramid Network (FPN) for feature extraction. 
        Total number of parameters:  8232320 (8.23 million) """ 
    def __init__(self, weights='DEFAULT', common_channel_size=256, output_channel_size=256):
        super(densetFPN_121, self).__init__()
        original_densenet = models.densenet121(weights=weights)
        
        # Initial layers: extract features without modification
        self.encoder = nn.ModuleList([
            nn.Sequential(*list(original_densenet.features.children())[:6], nn.Dropout(0.4)),   # 128x12x12
            nn.Sequential(*list(original_densenet.features.children())[6:8], nn.Dropout(0.4)),  # 256x6x6
            nn.Sequential(*list(original_densenet.features.children())[8:10], nn.Dropout(0.4)), # 896x3x3
            nn.Sequential(*list(original_densenet.features.children())[10:-1], nn.Dropout(0.4)) # 1920x3x3
        ])
        
        # Define convolutional layers for adapting channel sizes
        fpn_channels = [128, 256, 512, 1024]
        self.adaptation_layers = nn.ModuleDict({
            f'adapt{i+1}': nn.Conv2d(fpn_channels[i], common_channel_size, kernel_size=1)
            for i in range(4)
        })

        # Define FPN layers
        self.fpn = nn.ModuleDict({
            f'fpn{i+1}': nn.Conv2d(common_channel_size, common_channel_size, kernel_size=1)
            for i in range(3)
        })

        self.merge_layers = nn.Sequential(
            nn.Conv2d(common_channel_size, output_channel_size, kernel_size=3), # kernel size 1 or 3
            nn.BatchNorm2d(output_channel_size),
            nn.ReLU(),
            nn.Dropout(0.4) # 0.2
        )

    def forward(self, x):
        # Encoder
        features = []
        for encoder in self.encoder:
            x = encoder(x)
            features.append(x)
        
        # Merge channels using 1x1 convolutions
        adapted_features = [self.adaptation_layers[f'adapt{i+1}'](features[i]) for i in range(4)]
        
        # FPN integration using top-down pathway
        fpn_output = adapted_features.pop()  # Start with the deepest features
        for i in reversed(range(3)):
            upsampled = F.interpolate(fpn_output, size=adapted_features[i].shape[-2:], mode='nearest')
            fpn_output = self.fpn[f'fpn{i+1}'](upsampled + adapted_features[i])
        
        # Merge features
        merged_features = self.merge_layers(fpn_output)
        
        return merged_features


model = densetFPN_121(weights=None)

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(50, 3, 100, 100)

# Forward pass through the model with dummy input
features = model(dummy_input)

# Print output shapes to verify
print("Features shape:", features.shape)

# print(features.conv_info())

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class efficientDecoder_v2_s(nn.Module):
    def __init__(self, num_channels=[24, 48, 64, 128, 160, 256], output_channel_size=256, output_feature_size=25):
        super(efficientDecoder_v2_s, self).__init__()
        # Load EfficientNet V2 Small features
        efficientnet_v2_s = models.efficientnet_v2_s(weights='DEFAULT').features[:-1]

        # Modularize encoders
        self.encoders = nn.ModuleList([
            nn.Sequential(*list(efficientnet_v2_s.children())[:2], nn.Dropout(0.1)),    # 24x50x50
            nn.Sequential(*list(efficientnet_v2_s.children())[2:3], nn.Dropout(0.1)),   # 48x25x25
            nn.Sequential(*list(efficientnet_v2_s.children())[3:4], nn.Dropout(0.2)),   # 64x13x13
            nn.Sequential(*list(efficientnet_v2_s.children())[4:5], nn.Dropout(0.2)),   # 128x7x7
            nn.Sequential(*list(efficientnet_v2_s.children())[5:6], nn.Dropout(0.3)),   # 160x7x7 # TODO: Check whether to skip 128x7x7
            nn.Sequential(*list(efficientnet_v2_s.children())[6:7], nn.Dropout(0.3))    # 256x4x4
        ])
        
        # Modularize upconvolutions
        self.upconvs = nn.ModuleList([
            nn.ConvTranspose2d(in_channels=256, out_channels=160, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.Conv2d(in_channels=160, out_channels=128, kernel_size=1, stride=1),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(in_channels=64, out_channels=48, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(in_channels=48, out_channels=24, kernel_size=2, stride=2)
        ])
        
        # Modularize decoders
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(num_channels[i] * 2, num_channels[i], kernel_size=3, padding=1),
                nn.BatchNorm2d(num_channels[i]),
                nn.SiLU(),
                nn.Dropout(0.1 + i * 0.05)
            ) for i in range(len(num_channels)-1)
        ])

        # Optional, merge layers to increase the number of channels
        self.merge_layers = nn.Sequential(
            nn.Conv2d(24, output_channel_size, kernel_size=1),
            nn.BatchNorm2d(output_channel_size),
            nn.SiLU(),
            nn.Dropout(0.3)
        )
        
        self.global_avg_pool = nn.AdaptiveAvgPool2d(output_feature_size) # to reduce noise and overfitting

    def forward(self, x):
        # Encoder
        features = []
        for encoder in self.encoders:
            x = encoder(x)
            features.append(x)
        
        # Decoder
        x = features.pop()
        for upconv, decoder, feature in zip(self.upconvs, reversed(self.decoders), reversed(features)):
            x = upconv(x)
            x = torch.cat((x, feature), dim=1)
            x = decoder(x)
        
        x = self.merge_layers(x) # Introduced to increase the number of channels
        pooled_features = self.global_avg_pool(x) # Introduced to reduce noise and overfitting
        
        return pooled_features

model = efficientDecoder_v2_s()

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(50, 3, 100, 100)

# Forward pass through the model with dummy input
features = model(dummy_input)

# Print output shapes to verify
print("Features shape:", features.shape)

In [None]:
for i in reversed(range(1,6)):
    print(i)

In [None]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

class efficientDecoder_v2_s(nn.Module):
    def __init__(self, output_channel_size=256, output_feature_size=25):
        super(efficientDecoder_v2_s, self).__init__()
        # Load EfficientNet V2 Small features
        efficientnet_v2_s = models.efficientnet_v2_s(weights='DEFAULT').features[:-1]

        # Modularize encoders
        self.encoders = nn.ModuleList([
            nn.Sequential(*list(efficientnet_v2_s.children())[:2], nn.Dropout(0.1)),    # 24x50x50
            nn.Sequential(*list(efficientnet_v2_s.children())[2:3], nn.Dropout(0.1)),   # 48x25x25
            nn.Sequential(*list(efficientnet_v2_s.children())[3:4], nn.Dropout(0.2)),   # 64x13x13
            nn.Sequential(*list(efficientnet_v2_s.children())[4:5], nn.Dropout(0.2)),   # 128x7x7
            nn.Sequential(*list(efficientnet_v2_s.children())[5:6], nn.Dropout(0.3)),   # 160x7x7
            nn.Sequential(*list(efficientnet_v2_s.children())[6:7], nn.Dropout(0.3))    # 256x4x4
        ])
        
        # Modularize upconvolutions
        self.upconvs = nn.ModuleList([
            nn.ConvTranspose2d(in_channels=256, out_channels=160, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.Conv2d(in_channels=160, out_channels=128, kernel_size=1, stride=1),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(in_channels=64, out_channels=48, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(in_channels=48, out_channels=24, kernel_size=2, stride=2)
        ])
        
        # Modularize decoders
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(160*2, 160, kernel_size=3, padding=1),
                nn.BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.3, inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(128*2, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.3, inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(64*2, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.2, inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(48*2, 48, kernel_size=3, padding=1),
                nn.BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.2, inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(24*2, 24, kernel_size=3, padding=1),
                nn.BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.1, inplace=True)
            )
        ])

        # Optional, merge layers to increase the number of channels
        self.merge_layers = nn.Sequential(
            nn.Conv2d(24, output_channel_size, kernel_size=1),
            nn.BatchNorm2d(output_channel_size),
            nn.SiLU(),
            nn.Dropout(0.3)
        )
        
        self.global_avg_pool = nn.AdaptiveAvgPool2d(output_feature_size) # to reduce noise and overfitting

    def forward(self, x):
        # Encoder
        features = []
        for encoder in self.encoders:
            x = encoder(x)
            features.append(x)
        
        # Decoder
        x = features.pop()
        for upconv, decoder, feature in zip(self.upconvs, self.decoders, reversed(features)):
            x = upconv(x)
            x = torch.cat((x, feature), dim=1)
            x = decoder(x)
        
        x = self.merge_layers(x) # Introduced to increase the number of channels
        pooled_features = self.global_avg_pool(x) # Introduced to reduce noise and overfitting
        
        return pooled_features

model = efficientDecoder_v2_s()

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Example initialization and forward pass
# model = EfficientDecoder_v2_s()
dummy_input = torch.randn(10, 3, 100, 100)  # Adjust size according to your actual input
output = model(dummy_input)
print(f"Output shape: {output.shape}")

In [1]:
input_dim=(128,12,12)

input_channels, input_height, input_width = input_dim

print(input_channels, input_height, input_width)

128 12 12


In [9]:
import torch
from src.models.baseline_models import construct_baselineModel, construct_baseModel, BaseModel
from src.models.backbone_models import densetFPN_121, densetFPN_201

# model = densetFPN_121()
# Create the model instance
# model = BaseModel(backbone=densetFPN_121, weights='DEFAULT', input_dim=(256,12,12))
model = construct_baseModel(backbone_name='densetFPN_121', weights='DEFAULT', input_dim=(256,12,12))

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(50, 3, 100, 100)

# Forward pass through the model with dummy input
features = model(dummy_input)

# Print output shapes to verify
print("Features shape:", features.shape)

Total number of parameters:  45460865
torch.Size([50, 3, 100, 100])
torch.Size([50, 256, 12, 12])
Features shape: torch.Size([50, 1])


In [2]:
features

(tensor([[0.5106],
         [0.5035],
         [0.5488],
         [0.7421],
         [0.5720],
         [0.3836],
         [0.5911],
         [0.4598],
         [0.5526],
         [0.5835],
         [0.4567],
         [0.5489],
         [0.5642],
         [0.4666],
         [0.4395],
         [0.6266],
         [0.4210],
         [0.5086],
         [0.5422],
         [0.5515],
         [0.5906],
         [0.6073],
         [0.4439],
         [0.5127],
         [0.3540],
         [0.6443],
         [0.3925],
         [0.5958],
         [0.4544],
         [0.4678],
         [0.4648],
         [0.5237],
         [0.6828],
         [0.5871],
         [0.5303],
         [0.5280],
         [0.5333],
         [0.6036],
         [0.6479],
         [0.5992],
         [0.5446],
         [0.5203],
         [0.4284],
         [0.5404],
         [0.3937],
         [0.5419],
         [0.5040],
         [0.5163],
         [0.6542],
         [0.6807]], grad_fn=<SigmoidBackward0>),
 [tensor([[0.3055],


In [None]:
# Define the dictionary outside the class
MODEL_DICT = {
    'densetFPN_121': densetFPN_121,
    'densetFPN_201': densetFPN_201,
    'efficientFPN_v2_s': efficientFPN_v2_s,
    'efficientDecoder_v2_s': efficientDecoder_v2_s
}

class BaselineModel(nn.Module):
    def __init__(self, backbone, num_tasks=5, feature_dim=(256, 25, 25)):
        super(BaselineModel, self).__init__()
        self.backbone = backbone()  # Instantiate the backbone passed as a class
        
        feature_channels, feature_width, feature_height = feature_dim
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d((feature_width, feature_height))
        
        self.task_specific_layers = nn.ModuleList([
            nn.Sequential(
                nn.Flatten(),
                nn.Linear(feature_channels * feature_width * feature_height, 1024),
                nn.BatchNorm1d(1024),
                nn.ReLU(),
                nn.Dropout(0.2)
            ) for _ in range(num_tasks)
        ])
        
        self.task_specific_classifier = nn.ModuleList([
            nn.Linear(1024, 1) for _ in range(num_tasks)
        ])
        
        self.final_classifier = nn.Sequential(
            nn.Linear(1024 * num_tasks, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(1024, 1)
        )
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.adaptive_pool(x)
        
        intermediate_outputs = [layer(x) for layer in self.task_specific_layers]
        concatenated_outputs = torch.cat(intermediate_outputs, dim=1)
        
        task_outputs = [torch.sigmoid(classifier(io)) for io, classifier in zip(intermediate_outputs, self.task_specific_classifier)]
        
        final_output = torch.sigmoid(self.final_classifier(concatenated_outputs))
        
        return final_output, task_outputs

def construct_baselineModel(model_name, num_tasks=5, feature_dim=(256, 25, 25)):
    if model_name not in MODEL_DICT:
        raise ValueError(f"Unsupported model name {model_name}")
    backbone = MODEL_DICT[model_name]
    return BaselineModel(backbone, num_tasks, feature_dim)

In [5]:
import torch
from torchvision import models
from src.models.backbone_models import denseFPN_121, denseFPN_201

# Create the model instance
model = models.efficientnet_v2_s(weights=None).features[:-1]

print(model)

denseFPN_121(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_runni

In [None]:
import torch
from src.models.backbone_models import denseFPN_121, denseFPN_201

# Create the model instance
model = denseFPN_121()

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(50, 3, 100, 100)

# Forward pass through the model with dummy input
features = model(dummy_input)

# Print output shapes to verify
print("Features shape:", features.shape)

print("Output channels:", model.get_output_channels())

  warn(


Total number of parameters:  7643776
Features shape: torch.Size([50, 256, 12, 12])


KeyError: -1

## 1) Testing PPNet Model

In [17]:
import time
import torch

from helpers import list_of_distances, make_one_hot

def _train_or_test(model, dataloader, optimizer=None, class_specific=True, use_l1_mask=True,
                   coefs=None, log=print):
    '''
    model: the multi-gpu model
    dataloader:
    optimizer: if None, will be test evaluation
    '''
    is_train = optimizer is not None
    start = time.time()
    n_examples = 0
    n_correct = 0
    n_batches = 0
    total_cross_entropy = 0
    total_cluster_cost = 0
    # separation cost is meaningful only for class_specific
    total_separation_cost = 0
    total_avg_separation_cost = 0

    for i, (image, label) in enumerate(dataloader):
        input = image.cuda()
        target = label.cuda()

        # torch.enable_grad() has no effect outside of no_grad()
        grad_req = torch.enable_grad() if is_train else torch.no_grad()
        with grad_req:
            # nn.Module has implemented __call__() function
            # so no need to call .forward
            output, min_distances = model(input)

            # compute loss
            cross_entropy = torch.nn.functional.cross_entropy(output, target)

            if class_specific:
                max_dist = (model.module.prototype_shape[1]
                            * model.module.prototype_shape[2]
                            * model.module.prototype_shape[3])

                # prototypes_of_correct_class is a tensor of shape batch_size * num_prototypes
                # calculate cluster cost
                prototypes_of_correct_class = torch.t(model.module.prototype_class_identity[:,label]).cuda()
                inverted_distances, _ = torch.max((max_dist - min_distances) * prototypes_of_correct_class, dim=1)
                cluster_cost = torch.mean(max_dist - inverted_distances)

                # calculate separation cost
                prototypes_of_wrong_class = 1 - prototypes_of_correct_class
                inverted_distances_to_nontarget_prototypes, _ = \
                    torch.max((max_dist - min_distances) * prototypes_of_wrong_class, dim=1)
                separation_cost = torch.mean(max_dist - inverted_distances_to_nontarget_prototypes)

                # calculate avg cluster cost
                avg_separation_cost = \
                    torch.sum(min_distances * prototypes_of_wrong_class, dim=1) / torch.sum(prototypes_of_wrong_class, dim=1)
                avg_separation_cost = torch.mean(avg_separation_cost)
                
                if use_l1_mask:
                    l1_mask = 1 - torch.t(model.module.prototype_class_identity).cuda()
                    l1 = (model.module.last_layer.weight * l1_mask).norm(p=1)
                else:
                    l1 = model.module.last_layer.weight.norm(p=1) 

            else:
                min_distance, _ = torch.min(min_distances, dim=1)
                cluster_cost = torch.mean(min_distance)
                l1 = model.module.last_layer.weight.norm(p=1)

            # evaluation statistics
            _, predicted = torch.max(output.data, 1)
            n_examples += target.size(0)
            n_correct += (predicted == target).sum().item()

            n_batches += 1
            total_cross_entropy += cross_entropy.item()
            total_cluster_cost += cluster_cost.item()
            total_separation_cost += separation_cost.item()
            total_avg_separation_cost += avg_separation_cost.item()

        # compute gradient and do SGD step
        if is_train:
            if class_specific:
                if coefs is not None:
                    loss = (coefs['crs_ent'] * cross_entropy
                          + coefs['clst'] * cluster_cost
                          + coefs['sep'] * separation_cost
                          + coefs['l1'] * l1)
                else:
                    loss = cross_entropy + 0.8 * cluster_cost - 0.08 * separation_cost + 1e-4 * l1
            else:
                if coefs is not None:
                    loss = (coefs['crs_ent'] * cross_entropy
                          + coefs['clst'] * cluster_cost
                          + coefs['l1'] * l1)
                else:
                    loss = cross_entropy + 0.8 * cluster_cost + 1e-4 * l1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        del input
        del target
        del output
        del predicted
        del min_distances

    end = time.time()

    log('\ttime: \t{0}'.format(end -  start))
    log('\tcross ent: \t{0}'.format(total_cross_entropy / n_batches))
    log('\tcluster: \t{0}'.format(total_cluster_cost / n_batches))
    if class_specific:
        log('\tseparation:\t{0}'.format(total_separation_cost / n_batches))
        log('\tavg separation:\t{0}'.format(total_avg_separation_cost / n_batches))
    log('\taccu: \t\t{0}%'.format(n_correct / n_examples * 100))
    log('\tl1: \t\t{0}'.format(model.module.last_layer.weight.norm(p=1).item()))
    p = model.module.prototype_vectors.view(model.module.num_prototypes, -1).cpu()
    with torch.no_grad():
        p_avg_pair_dist = torch.mean(list_of_distances(p, p))
    log('\tp dist pair: \t{0}'.format(p_avg_pair_dist.item()))

    return n_correct / n_examples


def train(model, dataloader, optimizer, class_specific=False, coefs=None, log=print):
    assert(optimizer is not None)
    
    log('\ttrain')
    model.train()
    return _train_or_test(model=model, dataloader=dataloader, optimizer=optimizer,
                          class_specific=class_specific, coefs=coefs, log=log)


def test(model, dataloader, class_specific=False, log=print):
    log('\ttest')
    model.eval()
    return _train_or_test(model=model, dataloader=dataloader, optimizer=None,
                          class_specific=class_specific, log=log)


def last_only(model, log=print):
    for p in model.module.features.parameters():
        p.requires_grad = False
    for p in model.module.add_on_layers.parameters():
        p.requires_grad = False
    model.module.prototype_vectors.requires_grad = False
    for p in model.module.last_layer.parameters():
        p.requires_grad = True
    
    log('\tlast layer')


def warm_only(model, log=print):
    for p in model.module.features.parameters():
        p.requires_grad = False
    for p in model.module.add_on_layers.parameters():
        p.requires_grad = True
    model.module.prototype_vectors.requires_grad = True
    for p in model.module.last_layer.parameters():
        p.requires_grad = True
    
    log('\twarm')


def joint(model, log=print):
    for p in model.module.features.parameters():
        p.requires_grad = True
    for p in model.module.add_on_layers.parameters():
        p.requires_grad = True
    model.module.prototype_vectors.requires_grad = True
    for p in model.module.last_layer.parameters():
        p.requires_grad = True
    
    log('\tjoint')

Total number of parameters:  15688838
Backbone output shape: torch.Size([50, 512, 1, 1])
Prototype vectors: 50
Unsqueeze Prototype 0 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 1 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 2 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 3 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 4 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 5 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (50x51200 and 5120x1024)

In [1]:
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.optim import optimizer
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score, f1_score, precision_score, recall_score, roc_auc_score

def _train_or_test(model, data_loader, optimizer, device, is_train=True, use_l1_mask=True, coefs=None, task_weights=None):
    model.to(device)
    if is_train:
        model.train()
    else:
        model.eval()
    
    total_loss = 0
    
    total_correct = [0] * 5  # For tasks
    total_samples = [0] * 5  # For tasks
    
    task_cross_entropy = [0.0] * 5
    task_cluster_cost = [0.0] * 5
    task_separation_cost = [0.0] * 5
    task_avg_separation_cost = [0.0] * 5
    task_l1 = [0.0] * 5
    final_pred_targets = [[] for _ in range(5)]
    final_pred_outputs = [[] for _ in range(5)]
    
    final_correct = 0  # For final output
    final_samples = 0  # For final output
    final_targets = []  # For calculating balanced accuracy for final output
    final_outputs = []  # For calculating balanced accuracy for final output
    
    n_batches = 0
    context = torch.enable_grad() if is_train else torch.no_grad()
    with context:
        for X, targets, bweights_chars, final_target, bweight in tqdm(data_loader, leave=False):
            X = X.to(device)
            bweights_chars = [b.float().to(device) for b in bweights_chars]
            
            targets2 = [F.one_hot(t.squeeze(), num_classes=2).float().to(device) for t in targets]
            targets = [t.squeeze().to(device) for t in targets]
            final_target = final_target.float().unsqueeze(1).to(device)
            bweight = bweight.float().unsqueeze(1).to(device)
            
            final_output, task_outputs, min_distances = model(X)
            
            batch_loss = 0.0
            for i, (task_output, min_distance, target, bweight_char) in enumerate(zip(task_outputs, min_distances, targets, bweights_chars)):
                # Get the prototype identity for each characteristic
                prototype_char_identity = model.prototype_class_identity[i].to(device)
                
                # Get the max distance between prototypes
                max_dist = (model.prototype_shape[1] * model.prototype_shape[2] * model.prototype_shape[3])
                
                # Compute cross entropy cost for each characteristic
                cross_entropy = torch.nn.functional.cross_entropy(task_output, target, weight=bweight_char[0])

                # Compute cluster cost for each characteristic
                prototypes_of_correct_class = torch.t(prototype_char_identity[:,target]).to(device)    # batch_size * num_prototypes
                inverted_distances, _ = torch.max((max_dist - min_distance) * prototypes_of_correct_class, dim=1)
                cluster_cost = torch.mean(max_dist - inverted_distances) # Increase the distance between the prototypes of the same class

                # Compute separation cost for each characteristic
                prototypes_of_wrong_class = 1 - prototypes_of_correct_class
                inverted_distances_to_nontarget_prototypes, _ = torch.max((max_dist - min_distance) * prototypes_of_wrong_class, dim=1)
                separation_cost = torch.mean(max_dist - inverted_distances_to_nontarget_prototypes) # Decrease the distance between the prototypes of different classes

                # Compute average separation cost for each characteristic
                avg_separation_cost = torch.sum(min_distance * prototypes_of_wrong_class, dim=1) / torch.sum(prototypes_of_wrong_class, dim=1)
                avg_separation_cost = torch.mean(avg_separation_cost)
                
                # Compute l1 regularization for each characteristic
                if use_l1_mask:
                    l1_mask = 1 - torch.t(prototype_char_identity).to(device)
                    l1 = (model.task_specific_classifier[i].weight * l1_mask).norm(p=1)
                else:
                    l1 = model.task_specific_classifier[i].weight.norm(p=1) 
                    
                # Compute accuracy for each characteristic
                preds = task_output.argmax(dim=1)
                total_correct[i] += (preds == target).sum().item()
                total_samples[i] += target.size(0)
                
                # Collect data for balanced accuracy for each characteristic
                final_pred_targets[i].extend(target.cpu().numpy())
                final_pred_outputs[i].extend(preds.detach().cpu().numpy())

                task_cross_entropy[i] += cross_entropy.item()
                task_cluster_cost[i] += cluster_cost.item()
                task_separation_cost[i] += separation_cost.item()
                task_avg_separation_cost[i] += avg_separation_cost.item()
                task_l1[i] += l1
                
                batch_loss += (coefs['crs_ent'] * cross_entropy + 
                               coefs['clst'] * cluster_cost + 
                               coefs['sep'] * separation_cost +
                               coefs['l1'] * l1)

            # Compute binary cross entropy loss for final output
            final_loss = torch.nn.functional.binary_cross_entropy(final_output, final_target, weight=bweight)
            batch_loss += final_loss
            
            # Compute statistics for final accuracy
            final_preds = final_output.round()
            final_correct += (final_preds == final_target).sum().item()
            final_samples += final_target.size(0)
            final_targets.extend(final_target.cpu().numpy())
            final_outputs.extend(final_preds.detach().cpu().numpy())
            
            total_loss += batch_loss.item()  # Sum up total loss
            
            # compute gradient and do SGD step
            if is_train:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
                
            n_batches += 1
    
    # TODO: Add the seperate characteristic losses to the return dictionary, and include final F1 score, final precision, final recall, and final AUC
    average_loss = total_loss / n_batches
    task_accuracies = [correct / samples for correct, samples in zip(total_correct, total_samples)]
    task_balanced_accuracies = [balanced_accuracy_score(targets, outputs) for targets, outputs in zip(final_pred_targets, final_pred_outputs)]
    final_accuracy = final_correct / final_samples
    final_balanced_accuracy = balanced_accuracy_score(final_targets, final_outputs)
    # final_f1 = f1_score(final_targets, final_outputs)
    # final_precision = precision_score(final_targets, final_outputs)
    # final_recall = recall_score(final_targets, final_outputs)
    # final_auc = roc_auc_score(final_targets, final_outputs)
    
    # task_cross_entropy = [t / n_batches for t in task_cross_entropy]
    # task_cluster_cost = [t / n_batches for t in task_cluster_cost]
    # task_separation_cost = [t / n_batches for t in task_separation_cost]
    # task_avg_separation_cost = [t / n_batches for t in task_avg_separation_cost]
    # task_l1 = [t / n_batches for t in task_l1]
    
    # return the metrics as a dictionary
    metrics = {'average_loss': average_loss, 
               'task_accuracies': task_accuracies, 
               'task_balanced_accuracies': task_balanced_accuracies, 
               'final_accuracy': final_accuracy, 
               'final_balanced_accuracy': final_balanced_accuracy}
               # 'task_cross_entropy': task_cross_entropy,
               # 'task_cluster_cost': task_cluster_cost,
               # 'task_separation_cost': task_separation_cost,
               # 'task_avg_separation_cost': task_avg_separation_cost,
               # 'task_l1': task_l1}
    
    if is_train:
        return metrics
    else:
        return metrics

def train_ppnet(model, data_loader, optimizer, device, use_l1_mask=True, coefs=None, task_weights=None):
    train_metrics = _train_or_test(model, data_loader, optimizer, device, is_train=True, use_l1_mask=use_l1_mask, coefs=coefs, task_weights=task_weights)
    print(f"Train loss: {train_metrics['average_loss']:.5f}")
    for i, (acc, bal_acc) in enumerate(zip(train_metrics['task_accuracies'], train_metrics['task_balanced_accuracies']), 1):
        print(f"Task {i} - Train Accuracy: {acc*100:.2f}%, Train Balanced Accuracy: {bal_acc*100:.2f}%")
    # Print the metrics for the final output
    print(f"Final Output - Train Accuracy: {train_metrics['final_accuracy']*100:.2f}%, Train Balanced Accuracy: {train_metrics['final_balanced_accuracy']*100:.2f}%")
    return train_metrics

def test_ppnet(model, data_loader, device, use_l1_mask=True, coefs=None, task_weights=None):
    test_metrics = _train_or_test(model, data_loader, None, device, is_train=False, use_l1_mask=use_l1_mask, coefs=coefs, task_weights=task_weights)
    print(f"Test loss: {test_metrics['average_loss']:.5f}")
    for i, (acc, bal_acc) in enumerate(zip(test_metrics['task_accuracies'], test_metrics['task_balanced_accuracies']), 1):
        print(f"Task {i} - Test Accuracy: {acc*100:.2f}%, Test Balanced Accuracy: {bal_acc*100:.2f}%")
    # Print the metrics for the final output
    print(f"Final Output - Test Accuracy: {test_metrics['final_accuracy']*100:.2f}%, Test Balanced Accuracy: {test_metrics['final_balanced_accuracy']*100:.2f}%")
    return test_metrics
            
def last_only(model):
    for p in model.features.parameters():
        p.requires_grad = False
    for p in model.add_on_layers.parameters():
        p.requires_grad = False
    model.prototype_vectors.requires_grad = False
    for p in model.task_specific_classifier.parameters():
        p.requires_grad = True
    for p in model.final_classifier.parameters():
        p.requires_grad = True

def warm_only(model):
    for p in model.features.parameters():
        p.requires_grad = False
    for p in model.add_on_layers.parameters():
        p.requires_grad = True
    model.prototype_vectors.requires_grad = False
    for p in model.task_specific_classifier.parameters():
        p.requires_grad = True
    for p in model.final_classifier.parameters():
        p.requires_grad = True
        
def joint(model):
    for p in model.features.parameters():
        p.requires_grad = True
    for p in model.add_on_layers.parameters():
        p.requires_grad = True
    model.prototype_vectors.requires_grad = False
    for p in model.task_specific_classifier.parameters():
        p.requires_grad = True
    for p in model.final_classifier.parameters():
        p.requires_grad = True
    

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

from src.utils.receptive_field import compute_proto_layer_rf_info_v2
from src.models.backbone_models import denseFPN_121, denseFPN_201, efficientFPN_v2_s, efficientDecoder_v2_s, denseNet121, denseNet201

# Dictionary of supported backbone models
BACKBONE_DICT = {
    'denseNet121': denseNet121,
    'denseNet201': denseNet201,
    'efficientNetV2_s': models.efficientnet_v2_s(weights='DEFAULT').features[:-1],
    'denseFPN_121': denseFPN_121,
    'denseFPN_201': denseFPN_201,
    'efficientFPN_v2_s': efficientFPN_v2_s,
    'efficientDecoder_v2_s': efficientDecoder_v2_s
}

class PPNet(nn.Module):
    def __init__(self, features, img_size, prototype_shape, num_characteristics, proto_layer_rf_info=None, init_weights=True, prototype_activation_function='log', add_on_layers_type='bottleneck'):
        super(PPNet, self).__init__()
        # Define the input configurations
        self.img_size = img_size # size of the input images (e.g. (3, 224, 224))
        self.prototype_shape = prototype_shape # shape of the prototype vectors (e.g. (2000, 512, 1, 1))
        self.num_characteristics = num_characteristics # number of characteristics to predict (e.g. shape, margin, etc.)
        self.num_classes = 2 # binary classification
        
        self.num_prototypes = self.prototype_shape[0] # total number of prototypes
        self.prototypes_per_characteristic = self.num_prototypes // self.num_characteristics # number of prototypes per characteristic
        self.prototypes_per_class = self.prototypes_per_characteristic // self.num_classes # number of prototypes per class
        
        self.proto_layer_rf_info = proto_layer_rf_info
        self.epsilon = 1e-4 # small value to avoid numerical instability

        self.prototype_activation_function = prototype_activation_function # activation function for the prototypes
        
        self.prototype_class_identity = self._get_prototype_class_identity() # class identity of the prototypes
        
        # Define the feature extractor
        self.features = features
        
        # Define the add-on layers
        first_add_on_layer_in_channels = features.get_output_channels()
        # self.add_on_layers = self.initialize_add_on_layers(first_add_on_layer_in_channels, add_on_layers_type)
        if add_on_layers_type == 'bottleneck':
            add_on_layers = []
            current_in_channels = first_add_on_layer_in_channels
            while (current_in_channels > self.prototype_shape[1]) or (len(add_on_layers) == 0):
                current_out_channels = max(self.prototype_shape[1], (current_in_channels // 2))
                add_on_layers.append(nn.Conv2d(in_channels=current_in_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                add_on_layers.append(nn.ReLU())
                add_on_layers.append(nn.Conv2d(in_channels=current_out_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                if current_out_channels > self.prototype_shape[1]:
                    add_on_layers.append(nn.ReLU())
                else:
                    assert(current_out_channels == self.prototype_shape[1])
                    add_on_layers.append(nn.Sigmoid())
                current_in_channels = current_in_channels // 2
            self.add_on_layers = nn.Sequential(*add_on_layers)
        else:
            self.add_on_layers = nn.Sequential(
                nn.Conv2d(in_channels=first_add_on_layer_in_channels, out_channels=self.prototype_shape[1], kernel_size=1),
                nn.BatchNorm2d(self.prototype_shape[1]),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1),
                nn.Sigmoid()
                )
        
        # Define separate prototype vectors for each characteristic
        self.prototype_vectors = nn.ParameterList([
            nn.Parameter(torch.rand(self.prototypes_per_characteristic, prototype_shape[1], prototype_shape[2], prototype_shape[3]), requires_grad=True) for _ in range(self.num_characteristics)
        ])
        
        # Define a tensor of ones for the l2-convolution
        self.ones = nn.Parameter(torch.ones(self.prototypes_per_characteristic, prototype_shape[1], prototype_shape[2], prototype_shape[3]), requires_grad=False)

        # Define a separate classifier for each characteristic
        self.task_specific_classifier = nn.ModuleList([
            nn.Linear(self.prototypes_per_characteristic, self.num_classes) for _ in range(self.num_characteristics)   # Apply softmax to get confidence scores for each class of each characteristic
        ])
        
        self.final_classifier = nn.Sequential(
            nn.Flatten(),
            # nn.Linear(self.num_characteristics*self.prototypes_per_characteristic*12*12, self.num_characteristics*self.num_classes), # HxW is the output size of the feature extractor
            # nn.BatchNorm1d(self.num_characteristics*self.num_classes),
            # nn.ReLU(),
            # nn.Dropout(0.2),
            nn.Linear(self.num_characteristics*self.prototypes_per_characteristic, 1)
        )

        if init_weights:
            self._initialize_weights()
            self._set_last_layer_incorrect_connection(-0.5)
    
    def _get_prototype_class_identity(self):
        """
        Initialize the class identities of the prototypes structured by characteristics.
        Each characteristic has a tensor of size [num_prototypes_per_characteristic, num_classes].
        """
        prototype_class_identity = []
        num_prototypes_per_class = self.prototypes_per_characteristic // self.num_classes
        
        # Create a separate class identity matrix for each characteristic
        for _ in range(self.num_characteristics):
            # Initialize a zero matrix for current characteristic
            class_identity = torch.zeros(self.prototypes_per_characteristic, self.num_classes)
            
            # Assign prototypes to each class (binary: two classes per characteristic)
            for j in range(self.prototypes_per_characteristic):
                class_index = j // num_prototypes_per_class
                class_identity[j, class_index] = 1
            
            prototype_class_identity.append(class_identity)
        
        return prototype_class_identity
    
    def _set_last_layer_incorrect_connection(self, incorrect_strength):
        '''
        the incorrect strength will be actual strength if -0.5 then input -0.5
        '''
        for i in range(self.num_characteristics):
            positive_one_weights_locations = torch.t(self.prototype_class_identity[i])
            negative_one_weights_locations = 1 - positive_one_weights_locations
            
            correct_class_connection = 1
            incorrect_class_connection = incorrect_strength
            self.task_specific_classifier[i].weight.data.copy_(
                correct_class_connection * positive_one_weights_locations
                + incorrect_class_connection * negative_one_weights_locations)

    def _initialize_weights(self):
        for m in self.add_on_layers.modules():
            if isinstance(m, nn.Conv2d):
                # every init technique has an underscore _ in the name
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def _l2_convolution(self, x, prototype_vector):
        '''
        apply prototype_vector as l2-convolution filters on input x
        '''
        x2 = x ** 2
        x2_patch_sum = F.conv2d(input=x2, weight=self.ones)

        p2 = prototype_vector ** 2
        p2 = torch.sum(p2, dim=(1, 2, 3))
        # p2 is a vector of shape (num_prototypes_per_characteristic,)
        # then we reshape it to (num_prototypes_per_characteristic, 1, 1)
        p2_reshape = p2.view(-1, 1, 1)

        xp = F.conv2d(input=x, weight=prototype_vector)
        intermediate_result = - 2 * xp + p2_reshape  # use broadcast
        # x2_patch_sum and intermediate_result are of the same shape
        distances = F.relu(x2_patch_sum + intermediate_result)

        return distances

    def distance_2_similarity(self, distances):
        if self.prototype_activation_function == 'log':
            return torch.log((distances + 1) / (distances + self.epsilon))
        elif self.prototype_activation_function == 'linear':
            return -distances
        else:
            return self.prototype_activation_function(distances)
        
    def forward(self, x):
        # Extract features using the backbone
        x = self.features(x) # B x 256 x H x W
        
        # Apply add-on layers to the features
        x = self.add_on_layers(x) # B x 512 x H x W
        
        # Compute distances and task logits for each characteristic
        task_logits = []
        task_probabilities = []
        similarities = []
        min_distances = []
        for i in range(self.num_characteristics):
            distance = self._l2_convolution(x, self.prototype_vectors[i]) # B x num_prototypes_per_characteristic x H x W, B x num_prototypes_per_characteristic x 1 x 1
            min_distance = -F.max_pool2d(-distance, kernel_size=(distance.size()[2], distance.size()[3])) # B x num_prototypes_per_characteristic x 1 x 1
            min_distance = min_distance.view(-1, self.prototypes_per_characteristic) # B x num_prototypes_per_characteristic
            similarity = self.distance_2_similarity(min_distance) # B x num_prototypes_per_characteristic
            task_logit = self.task_specific_classifier[i](similarity) # B x 2
            task_probability = F.softmax(task_logit, dim=1)
                        
            similarities.append(similarity)
            min_distances.append(min_distance)
            task_logits.append(task_logit)
            task_probabilities.append(task_probability)
        
        # Concatenate task distances for the final classifier
        # final_output = torch.sigmoid(self.final_classifier(torch.cat(distances, dim=1))) # TODO: Use intermediate task logits instead of distances and feature extractor output
        print(torch.cat(similarities, dim=1).shape)
        final_output = torch.sigmoid(self.final_classifier(torch.cat(similarities, dim=1)))
        return final_output, task_logits, min_distances
                    
    def push_forward(self, x):
        '''this method is needed for the pushing operation'''
        x = self.features(x)
        x = self.add_on_layers(x)
        distances = []
        for i in range(self.num_characteristics):
            distance = self._l2_convolution(x, self.prototype_vectors[i])
            distances.append(distance)
        return x, distances
    
    
    # TODO: Implement the pruning operation


def construct_PPNet(
    base_architecture='denseNet121', 
    weights='DEFAULT', 
    img_size=224,
    prototype_shape=(50*5*2, 224, 1, 1), 
    num_characteristics=5,
    prototype_activation_function='log',
    add_on_layers_type='bottleneck'
):
    
    features = BACKBONE_DICT[base_architecture](weights=weights)
    
    layer_filter_sizes, layer_strides, layer_paddings = features.conv_info()
    
    proto_layer_rf_info = compute_proto_layer_rf_info_v2(
        img_size=img_size,  
        layer_filter_sizes=layer_filter_sizes,
        layer_strides=layer_strides,
        layer_paddings=layer_paddings,
        prototype_kernel_size=prototype_shape[2]
    )
    
    return PPNet(
        features=features,
        img_size=img_size,
        prototype_shape=prototype_shape,
        num_characteristics=num_characteristics,
        proto_layer_rf_info=proto_layer_rf_info,
        init_weights=True,
        prototype_activation_function=prototype_activation_function,
        add_on_layers_type=add_on_layers_type
    )
    
# Set torch seed for reproducibility
torch.manual_seed(42)

# model = PPNet(features=denseFPN_121(), img_size=(3, 100, 100), prototype_shape=(10*5*2, 224, 1, 1), num_characteristics=5, init_weights=True, prototype_activation_function='log', add_on_layers_type='bottleneck')

model = construct_PPNet(base_architecture='denseFPN_121', weights='DEFAULT', img_size=100, prototype_shape=(10*5*2, 224, 1, 1), num_characteristics=5, prototype_activation_function='log', add_on_layers_type='bottleneck')

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(10, 3, 100, 100)

# Forward pass through the model with dummy input
final_output, task_logits, min_distances = model(dummy_input)

# Print output shapes to verify
print("task_logits shape:", [task_logit.shape for task_logit in task_logits])  
print("final_output shape:", final_output.shape)

Total number of parameters:  7778845
torch.Size([10, 20])
torch.Size([10, 20])
torch.Size([10, 20])
torch.Size([10, 20])
torch.Size([10, 20])
torch.Size([10, 100])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x100 and 10x1)

In [1]:
from src.loaders.dataloader import LIDCDataset
import torchvision.transforms as transforms
from PIL import Image
import os
import torch

IMG_CHANNELS = 3
IMG_SIZE = 100
CHOSEN_CHARS = [False, True, False, True, True, False, False, True, True]
labels_file = './dataset/Meta/meta_info_old.csv'

mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3), 
    transforms.Resize(256),  # First resize to larger dimensions
    transforms.CenterCrop(224),  # Then crop to 224x224
    transforms.ToTensor(),  # Convert to tensor (also scales pixel values to [0, 1])
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# train set
LIDC_trainset = LIDCDataset(labels_file=labels_file, chosen_chars=CHOSEN_CHARS, auto_split=True, zero_indexed=False, 
                                                        transform=transforms.Compose([transforms.Grayscale(num_output_channels=IMG_CHANNELS), 
                                                                    transforms.Resize(size=(IMG_SIZE, IMG_SIZE), interpolation=Image.BILINEAR), 
                                                                    transforms.ToTensor(), 
                                                                    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                                    transforms.Normalize(mean, std)
                                                                    ]),
                                                        train=True)
train_dataloader = torch.utils.data.DataLoader(LIDC_trainset, batch_size=50, shuffle=True, num_workers=4)
# test set
LIDC_testset = LIDCDataset(labels_file=labels_file, chosen_chars=CHOSEN_CHARS, auto_split=True, zero_indexed=False, 
                                                        transform=transforms.Compose([transforms.Grayscale(num_output_channels=IMG_CHANNELS), 
                                                                    transforms.Resize(size=(IMG_SIZE, IMG_SIZE), interpolation=Image.BILINEAR), 
                                                                    transforms.ToTensor(), 
                                                                    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                                    transforms.Normalize(mean, std)
                                                                    ]), 
                                                        train=False)
test_dataloader = torch.utils.data.DataLoader(LIDC_testset, batch_size=50, shuffle=True, num_workers=4)


  warn(


In [2]:
from src.training.train_ppnet import train_ppnet, test_ppnet
from src.models.ProtoPNet import construct_PPNet
import torch

epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = construct_PPNet(base_architecture='denseFPN_121', weights='DEFAULT', img_size=100, prototype_shape=(10*5*2, 224, 1, 1), num_characteristics=5, prototype_activation_function='log', add_on_layers_type='bottleneck')

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
coefs = {
    'crs_ent': 1,#0.4,#1.1,#0.8,#1.1,#1,#0.5,#changed from 1 at 48
    'clst': 0.8*1.5,#0.2,#0.3,#1.1,#0.8,#5,#0.8,
    'sep': -0.0004,#-0.0004,#-0.17,#-0.22,#-0.5, #used to be -0.08 #dm made it smaller to avoid the problem I noticed where as the separation loss is subtracted, having it too large makes your loss negative making everything explode (also -0.025 works but unstable)
    'l1': 1e-4
}

task_weights = [1.0 / 5] * 5
for epoch in range(epochs):
    print("\n" + "-"*100 + f"\nEpoch: {epoch + 1}/{epochs},\t" + f"Task Weights: {[f'{weight:.2f}' for weight in task_weights]}\n" + "-"*100)
    train_metrics,task_weights = train_ppnet(model, train_dataloader, optimizer, device, coefs=coefs, task_weights=task_weights)
    test_metrics = test_ppnet(model, test_dataloader, device, coefs=coefs)



----------------------------------------------------------------------------------------------------
Epoch: 1/10,	Task Weights: ['0.20', '0.20', '0.20', '0.20', '0.20']
----------------------------------------------------------------------------------------------------




Train loss: 16.54768
Task 1 - Train Loss: 31.90156, Train Accuracy: 96.96%, Train Balanced Accuracy: 96.96%
Task 2 - Train Loss: 35.12752, Train Accuracy: 51.39%, Train Balanced Accuracy: 49.86%
Task 3 - Train Loss: 35.08484, Train Accuracy: 57.85%, Train Balanced Accuracy: 50.00%
Task 4 - Train Loss: 57.84328, Train Accuracy: 41.99%, Train Balanced Accuracy: 49.80%
Task 5 - Train Loss: 38.85134, Train Accuracy: 47.68%, Train Balanced Accuracy: 50.00%
Final Output - Train Accuracy: 55.57%, Train Balanced Accuracy: 49.86%, Train F1 Score: 30.97%


  _warn_prf(average, modifier, msg_start, len(result))


Test loss: 25.95394
Task 1 - Test Accuracy: 100.00%, Test Balanced Accuracy: 100.00%
Task 2 - Test Accuracy: 71.98%, Test Balanced Accuracy: 50.00%
Task 3 - Test Accuracy: 56.90%, Test Balanced Accuracy: 50.00%
Task 4 - Test Accuracy: 21.02%, Test Balanced Accuracy: 50.03%
Task 5 - Test Accuracy: 48.04%, Test Balanced Accuracy: 50.00%
Final Output - Test Accuracy: 63.07%, Test Balanced Accuracy: 50.00%

----------------------------------------------------------------------------------------------------
Epoch: 2/10,	Task Weights: ['0.18', '0.19', '0.19', '0.25', '0.20']
----------------------------------------------------------------------------------------------------


                                                 

Train loss: 3.85263
Task 1 - Train Loss: 22.59399, Train Accuracy: 100.00%, Train Balanced Accuracy: 100.00%
Task 2 - Train Loss: 20.73491, Train Accuracy: 73.47%, Train Balanced Accuracy: 50.00%
Task 3 - Train Loss: 18.54049, Train Accuracy: 57.85%, Train Balanced Accuracy: 50.00%
Task 4 - Train Loss: 14.95040, Train Accuracy: 25.07%, Train Balanced Accuracy: 50.48%
Task 5 - Train Loss: 15.31349, Train Accuracy: 47.99%, Train Balanced Accuracy: 50.26%
Final Output - Train Accuracy: 60.36%, Train Balanced Accuracy: 49.85%, Train F1 Score: 11.18%


                                               

Test loss: 9.61602
Task 1 - Test Accuracy: 100.00%, Test Balanced Accuracy: 100.00%
Task 2 - Test Accuracy: 71.98%, Test Balanced Accuracy: 50.00%
Task 3 - Test Accuracy: 56.90%, Test Balanced Accuracy: 50.00%
Task 4 - Test Accuracy: 34.15%, Test Balanced Accuracy: 51.29%
Task 5 - Test Accuracy: 48.23%, Test Balanced Accuracy: 50.17%
Final Output - Test Accuracy: 57.38%, Test Balanced Accuracy: 51.10%

----------------------------------------------------------------------------------------------------
Epoch: 3/10,	Task Weights: ['0.22', '0.21', '0.20', '0.18', '0.18']
----------------------------------------------------------------------------------------------------


                                                 

Train loss: 2.17127
Task 1 - Train Loss: 24.21631, Train Accuracy: 100.00%, Train Balanced Accuracy: 100.00%
Task 2 - Train Loss: 15.93818, Train Accuracy: 73.47%, Train Balanced Accuracy: 50.00%
Task 3 - Train Loss: 14.42712, Train Accuracy: 57.86%, Train Balanced Accuracy: 50.01%
Task 4 - Train Loss: 11.66857, Train Accuracy: 36.30%, Train Balanced Accuracy: 51.03%
Task 5 - Train Loss: 11.24044, Train Accuracy: 48.54%, Train Balanced Accuracy: 50.76%
Final Output - Train Accuracy: 49.89%, Train Balanced Accuracy: 50.61%, Train F1 Score: 44.75%


                                      

In [1]:
from src.models.backbone_models import denseNet121

model = denseNet121()

kernel, stride, padding = model.conv_info()

print(len(kernel), len(stride), len(padding))
print(kernel)
print(stride)
print(padding)

  warn(


120 120 120
[7, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3]
[2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[3, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0

In [8]:
from src.models.backbone_models import denseFPN_121, efficientFPN_v2_s

model = efficientFPN_v2_s()

kernel, stride, padding = model.conv_info()

print(len(kernel), len(stride), len(padding))
print(kernel)
print(stride)
print(padding)

181 181 181
[3, 3, 3, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [76]:
from src.models.densenet_features import densenet121_features

model = densenet121_features()

kernel, stride, padding = model.conv_info()

print(kernel)
print(stride)
print(padding)

[7, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3]
[2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[3, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0

In [36]:
import torch.nn.functional as F

target = [[1,1,1,1,1, 1,1,1,1,1],[1, 0, 1, 0, 1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]]
targets = torch.tensor(target)
# one_hot_targets = [F.make]
print(targets)
one_hot_targets = [F.one_hot(t.long().squeeze()) for t in targets]
print(one_hot_targets)

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]])
[tensor([[0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1]]), tensor([[0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0]]), tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])]


## 2) Testing XPNet Implementation

In [3]:
from src.loaders.dataloaderv2 import LIDCDataset
import torchvision.transforms as transforms
from PIL import Image
import torch

labels_file = './dataset/Meta/meta_info_old.csv'
IMG_CHANNELS = 3
IMG_SIZE = 100
CHOSEN_CHARS = [False, False, False, True, True, False, False, True, True]

# labels_file = './dataset/Meta/meta_info_old.csv'
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# train set
LIDC_trainset = LIDCDataset(labels_file=labels_file, chosen_chars=CHOSEN_CHARS, auto_split=True, zero_indexed=False, 
                                                        transform=transforms.Compose([transforms.Grayscale(num_output_channels=IMG_CHANNELS), 
                                                                    transforms.Resize(size=(IMG_SIZE, IMG_SIZE), interpolation=Image.BILINEAR), 
                                                                    transforms.ToTensor(), 
                                                                    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                                    transforms.Normalize(mean, std)
                                                                    ]),
                                                        train=True)
train_dataloader = torch.utils.data.DataLoader(LIDC_trainset, batch_size=10, shuffle=True, num_workers=4)

LIDC_testset = LIDCDataset(labels_file=labels_file, chosen_chars=CHOSEN_CHARS, auto_split=True, zero_indexed=False, 
                                                        transform=transforms.Compose([transforms.Grayscale(num_output_channels=IMG_CHANNELS), 
                                                                    transforms.Resize(size=(IMG_SIZE, IMG_SIZE), interpolation=Image.BILINEAR), 
                                                                    transforms.ToTensor(), 
                                                                    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                                    transforms.Normalize(mean, std)
                                                                    ]), 
                                                        train=False)
test_dataloader = torch.utils.data.DataLoader(LIDC_testset, batch_size=10, shuffle=True, num_workers=4)

print(len(train_dataloader))
print(len(test_dataloader))

1192
200


In [6]:
import torch
import torch.nn as nn
from src.utils.receptive_field import compute_proto_layer_rf_info_v2
from src.models.ProtoPNet import PPNet, BACKBONE_DICT

class XProtoNet(PPNet):
    def __init__(self, **kwargs):
        super(XProtoNet, self).__init__(**kwargs)

        # self.cnn_backbone = self.features
        # del self.features
        # cnn_backbone_out_channels = self.cnn_backbone.get_output_channels()

        cnn_backbone_out_channels = self.features.get_output_channels()
        
        # feature extractor module
        # self.add_on_layers = torch.nn.Sequential(*list(self.add_on_layers.children())[:-1])
        self.add_on_layers_module = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels=cnn_backbone_out_channels, out_channels=self.prototype_shape[1], kernel_size=1),
                nn.BatchNorm2d(self.prototype_shape[1]),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1),
                nn.Sigmoid()
            ) for _ in range(self.num_characteristics)
        ])
        # self._initialize_weights(self.add_on_layers)

        # Occurrence map module
        self.occurrence_module = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels=cnn_backbone_out_channels,
                    out_channels=self.prototype_shape[1],
                    kernel_size=1,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=self.prototype_shape[1],
                    out_channels=self.prototype_shape[1] // 2,
                    kernel_size=1,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=self.prototype_shape[1] // 2,
                    out_channels=self.prototypes_per_characteristic,
                    kernel_size=1,
                    bias=False,
                ),
                # nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[0], kernel_size=1, bias=False),
            ) for _ in range(self.num_characteristics)
        ])
        # self._initialize_weights(self.occurrence_module)

        # Last classification layer, redefine to initialize randomly
        self.task_specific_classifier = nn.ModuleList([
            nn.Linear(self.prototypes_per_characteristic, self.num_classes, bias=False) for _ in range(self.num_characteristics)   # Apply softmax to get confidence scores for each class of each characteristic
        ])
        
        self.final_add_on_layers = nn.Sequential(
            nn.Conv2d(in_channels=cnn_backbone_out_channels, out_channels=self.prototype_shape[1], kernel_size=1),
            nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototypes_per_characteristic*self.num_characteristics, kernel_size=1),
            nn.BatchNorm2d(self.prototypes_per_characteristic*self.num_characteristics),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.AdaptiveAvgPool2d(1)
        )
        
        self.final_classifier = nn.Sequential(
            # nn.flatten(),
            nn.Linear(2*self.prototypes_per_characteristic*self.num_characteristics, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1)
        )
            
        
        self._set_last_layer_incorrect_connection(incorrect_strength=0)

        self.om_softmax = nn.Softmax(dim=-1)
        self.cosine_similarity = nn.CosineSimilarity(dim=2)

    def forward(self, x):
        # Feature Extractor Layer
        x = self.features(x)
        
        # Hierarchical Prototype Layer
        task_logits = []
        similarities = []
        occurrence_maps = []
        for i in range(self.num_characteristics):
            feature_map = self.add_on_layers_module[i](x).unsqueeze(1)  # shape (N, 1, 128, H, W)
            
            occurrence_map = self.get_occurence_map_absolute_val(x, i)  # shape (N, P, 1, H, W)
            
            features_extracted = (occurrence_map * feature_map).sum(dim=3).sum(dim=3)  # shape (N, P, 128)
            
            similarity = self.cosine_similarity(
                features_extracted, self.prototype_vectors[i].squeeze().unsqueeze(0)
            )  # shape (N, P)
            similarity = (similarity + 1) / 2.0  # normalizing to [0,1] for positive reasoning

            # classification layer
            task_logit = self.task_specific_classifier[i](similarity)
            
            occurrence_maps.append(occurrence_map)
            similarities.append(similarity)
            task_logits.append(task_logit)
        
        # Prepare similarity vector
        similarity_vector = torch.cat(similarities, dim=1)

        # Process through final add-on layers
        final_layer_feature_map = self.final_add_on_layers(x).squeeze()

        # Concatenate all features
        final_layer_input = torch.cat((similarity_vector, final_layer_feature_map), dim=1)
        
        # Final Classification Layer
        final_output = torch.sigmoid(self.final_classifier(final_layer_input))

        return final_output, task_logits, similarities, occurrence_maps

    def compute_occurence_map(self, x, characteristic_index):
        # Feature Extractor Layer
        x = self.features(x)
        occurrence_map = self.get_occurence_map_absolute_val(x, characteristic_index)  # shape (N, P, 1, H, W)
        return occurrence_map

    def get_occurence_map_softmaxed(self, x, characteristic_index):
        occurrence_map = self.occurrence_module[characteristic_index](x)  # shape (N, P, H, W)
        n, p, h, w = occurrence_map.shape
        occurrence_map = occurrence_map.reshape((n, p, -1))
        occurrence_map = self.om_softmax(occurrence_map).reshape((n, p, h, w)).unsqueeze(2)  # shape (N, P, 1, H, W)
        return occurrence_map

    def get_occurence_map_absolute_val(self, x, characteristic_index):
        occurrence_map = self.occurrence_module[characteristic_index](x)  # shape (N, P, H, W)
        occurrence_map = torch.abs(occurrence_map).unsqueeze(2)  # shape (N, P, 1, H, W)
        return occurrence_map
    
    def push_forward(self, x):
        """
        this method is needed for the pushing operation
        """
        # Feature Extractor Layer
        x = self.features(x)
        
        features_extracted_list = []
        inverted_similarity_list = []
        occurrence_map_list = []
        preds_list = []
        for characteristic_index in range(self.num_characteristics):
            feature_map = self.add_on_layers_module[characteristic_index](x).unsqueeze(1)  # shape (N, 1, 128, H, W)
            occurrence_map = self.get_occurence_map_absolute_val(x,characteristic_index)  # shape (N, P, 1, H, W)
            features_extracted = (occurrence_map * feature_map).sum(dim=3).sum(dim=3)  # shape (N, P, 128)

            # Prototype Layer
            similarity = self.cosine_similarity(
                features_extracted, self.prototype_vectors[characteristic_index].squeeze().unsqueeze(0)
            )  # shape (N, P)
            similarity = (similarity + 1) / 2.0  # normalizing to [0,1] for positive reasoning

            # classification layer
            logits = self.task_specific_classifier[characteristic_index](similarity)
            preds = logits.softmax(dim=1)
            
            features_extracted_list.append(features_extracted)
            inverted_similarity_list.append(1-similarity)
            occurrence_map_list.append(occurrence_map)
            preds_list.append(preds)

        # return features_extracted, 1 - similarity, occurrence_map, logits
        return features_extracted_list, inverted_similarity_list, occurrence_map_list, preds_list
    
def construct_XPNet(
    base_architecture,
    weights='DEFAULT',
    img_size=100,
    prototype_shape=(10*4*2, 128, 1, 1),
    num_characteristics=4,
    prototype_activation_function="log",
    add_on_layers_type="regular",
):
    features = BACKBONE_DICT[base_architecture](weights=weights)
    layer_filter_sizes, layer_strides, layer_paddings = features.conv_info()
    proto_layer_rf_info = compute_proto_layer_rf_info_v2(
        img_size=img_size,
        layer_filter_sizes=layer_filter_sizes,
        layer_strides=layer_strides,
        layer_paddings=layer_paddings,
        prototype_kernel_size=prototype_shape[2],
    )
    return XProtoNet(
        features=features,
        img_size=img_size,
        prototype_shape=prototype_shape,
        proto_layer_rf_info=proto_layer_rf_info,
        num_characteristics=num_characteristics,
        init_weights=True,
        prototype_activation_function=prototype_activation_function,
        add_on_layers_type=add_on_layers_type,
    )
# Set torch seed for reproducibility
torch.manual_seed(42)

# model = PPNet(features=denseFPN_121(), img_size=(3, 100, 100), prototype_shape=(10*5*2, 224, 1, 1), num_characteristics=5, init_weights=True, prototype_activation_function='log', add_on_layers_type='bottleneck')

model = construct_XProtoNet(base_architecture='denseFPN_121', weights='DEFAULT', img_size=100, prototype_shape=(10*4*2, 224, 1, 1), num_characteristics=4, prototype_activation_function='log', add_on_layers_type='bottleneck')

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(50, 3, 100, 100)

# Forward pass through the model with dummy input
final_output, task_logits, similarities, occurance_maps = model(dummy_input)

# Print output shapes to verify
print("task_logits shape:", [task_logit.shape for task_logit in task_logits])  
print("final_output shape:", final_output.shape)
print("occurance_map shape:", [occurance_map.shape for occurance_map in occurance_maps])

Total number of parameters:  8665713
task_logits shape: [torch.Size([50, 2]), torch.Size([50, 2]), torch.Size([50, 2]), torch.Size([50, 2])]
final_output shape: torch.Size([50, 1])
occurance_map shape: [torch.Size([50, 20, 1, 12, 12]), torch.Size([50, 20, 1, 12, 12]), torch.Size([50, 20, 1, 12, 12]), torch.Size([50, 20, 1, 12, 12])]


In [4]:
import torch
from tqdm import tqdm

from src.loss.loss import (
    CeLoss,
    ClusterRoiFeat,
    SeparationRoiFeat,
    OrthogonalityLoss,
    L_norm,
    TransformLoss,
)

from sklearn.metrics import balanced_accuracy_score, f1_score, recall_score, roc_auc_score


def _adjust_weights(task_losses, exponent=2, target_sum=5):
    """
    Adjusts the weights based on the task losses, using the sum of all task losses for normalization.
    
    Args:
        task_losses (list): List of losses for each task.
        exponent (int): The exponent used for calculating the inverse weights. Defaults to 2.
        target_sum (int): The total sum to which the weights should scale. Defaults to 5.

    Returns:
        list: A list of adjusted weights for each task.
    """
    # Calculate the total sum of all task losses
    total_loss = sum(task_losses)
    # Normalize each loss by the total sum of losses
    normalized_losses = [loss / total_loss for loss in task_losses] if total_loss > 0 else [0] * len(task_losses)
    # Calculate weights using the normalized losses
    weights = [1.0 / ((1.0 - loss) ** exponent + 1e-6) for loss in normalized_losses]
    total_weight = sum(weights)
    scaled_weights = [w / total_weight * target_sum for w in weights]
    return scaled_weights

def _train_or_test(model, data_loader, optimizer, device, is_train=True, use_l1_mask=True, coefs=None, task_weights=None):
    model.to(device)
    if is_train:
        model.train()
    else:
        model.eval()
    
    num_characteristics = model.num_characteristics
    num_classes = model.num_classes
    
    # Initialize the loss functions
    CrossEntropy = CeLoss(loss_weight=coefs['crs_ent'])
    Cluster = ClusterRoiFeat(loss_weight=coefs['clst'], num_classes=num_classes)
    Separation = SeparationRoiFeat(loss_weight=coefs['sep'], num_classes=num_classes)
    # Orthogonality = OrthogonalityLoss(loss_weight=coefs['orth'], num_classes=num_classes)
    Transform = TransformLoss(loss_weight=coefs['trans'])
    L1_occ = L_norm(loss_weight=coefs['l1_occ'], mask=None, reduction="mean", p=1)
    
    # Initialize the task losses for each characteristic
    task_total_losses = [0.0] * num_characteristics
    task_cross_entropy = [0.0] * num_characteristics
    task_cluster_cost = [0.0] * num_characteristics
    task_separation_cost = [0.0] * num_characteristics
    task_l1 = [0.0] * num_characteristics
    task_occ_cost = [0.0] * num_characteristics
    task_targets_all = [[] for _ in range(num_characteristics)]
    task_predictions_all = [[] for _ in range(num_characteristics)]

    # Initialize the final output losses
    final_total_loss = 0.0
    final_targets_all = []
    final_predictions_all = []
    
    # Initialize the total loss
    total_loss = 0.0
    
    n_batches = 0
    context = torch.enable_grad() if is_train else torch.no_grad()
    with context:
        for X, targets, bweights_chars, final_target, bweight in tqdm(data_loader, leave=False):
            X = X.to(device)
            bweights_chars = [b.float().to(device) for b in bweights_chars]            
            targets = [t.squeeze().to(device) for t in targets]
            final_target = final_target.float().unsqueeze(1).to(device)
            bweight = bweight.float().unsqueeze(1).to(device)
            
            final_output, task_outputs, similarities, occurrence_maps = model(X)
            
            ############################ Compute Losses ############################
            
            batch_loss = 0.0
            for characteristic_idx, (task_output, similarity, occurrence_map, target, bweight_char) in enumerate(zip(task_outputs, similarities, occurrence_maps, targets, bweights_chars)):
                # Get the prototype identity for the current characteristic
                prototype_char_identity = model.prototype_class_identity[characteristic_idx].to(device)
                
                # Compute cross entropy cost - to encourage the correct classification of the input
                cross_entropy_cost = CrossEntropy.compute(task_output, target)
                
                # Compute cluster cost - to encourage similarity among prototypes of the same class
                cluster_cost = Cluster.compute(similarity, target)
                
                # Compute separation cost - to encourage diversity among prototypes of different classes
                separation_cost = Separation.compute(similarity, target)
                
                # TODO: Compute Orthogonality loss - to encourage diversity among prototypes
                
                # Compute l1 regularization on task-specific classifier weights - to encourage sparsity in the weights
                l1 = L_norm(loss_weight=coefs['l1'], mask=(1 - torch.t(prototype_char_identity).to(device))).compute(model.task_specific_classifier[characteristic_idx].weight)
                
                # Compute Occurance Map Transformation Regularization - to encourage the occurance map to be generalize better
                occ_trans_cost = Transform.compute(X, occurrence_map, model, characteristic_idx)
                
                # Compute Occurance Map L1 Regularization - to make the occruance map as small as possible to avoid covering more regions than necessary
                occ_l1 = L1_occ.compute(occurrence_map, dim=(-2, -1))
                
                occurance_map_cost = occ_trans_cost + occ_l1
                
                # Update the different task losses for each characteristic
                task_cross_entropy[characteristic_idx] += cross_entropy_cost.item()
                task_cluster_cost[characteristic_idx] += cluster_cost.item()
                task_separation_cost[characteristic_idx] += separation_cost.item()
                task_l1[characteristic_idx] += l1.item()
                task_occ_cost[characteristic_idx] += occurance_map_cost.item()
                
                # Collect the different losses for each characteristic
                task_loss = (
                    cross_entropy_cost 
                    + cluster_cost 
                    + separation_cost 
                    + l1
                    + occurance_map_cost
                )
                
                # Update the task total losses
                task_total_losses[characteristic_idx] += task_loss.item()
                
                # Apply task weights if provided
                if task_weights:
                    task_loss *= task_weights[characteristic_idx]
                
                # Update the total loss for the batch
                batch_loss += task_loss                
                
                # Collect statistics for each characteristic's prediction metrics
                preds = task_output.argmax(dim=1)
                task_targets_all[characteristic_idx].extend(target.cpu().numpy())
                task_predictions_all[characteristic_idx].extend(preds.detach().cpu().numpy())                

            # Compute binary cross entropy loss for final output
            final_loss = torch.nn.functional.binary_cross_entropy(final_output, final_target, weight=bweight)
            batch_loss += final_loss
            
            # Collect statistics for final prediction metrics
            final_total_loss += final_loss.item()
            final_preds = final_output.round()
            final_targets_all.extend(final_target.cpu().numpy())
            final_predictions_all.extend(final_preds.detach().cpu().numpy())
            
            total_loss += batch_loss.item()  # Sum up total loss
            
            # Compute gradient and do SGD step
            if is_train:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
                
            n_batches += 1
    
    ############################ Compute Metrics ############################
    
    average_loss = total_loss / n_batches
    
    task_losses = [t / n_batches for t in task_total_losses]
    task_cross_entropy = [t / n_batches for t in task_cross_entropy]
    task_cluster_cost = [t / n_batches for t in task_cluster_cost]
    task_separation_cost = [t / n_batches for t in task_separation_cost]
    task_l1 = [t / n_batches for t in task_l1]
    task_occ_cost = [t / n_batches for t in task_occ_cost]
    task_balanced_accuracies = [balanced_accuracy_score(targets, outputs) for targets, outputs in zip(task_targets_all, task_predictions_all)]
    
    final_loss = final_total_loss / n_batches
    final_balanced_accuracy = balanced_accuracy_score(final_targets_all, final_predictions_all)
    final_f1 = f1_score(final_targets_all, final_predictions_all)
    # final_precision = precision_score(final_targets_all, final_predictions_all)
    final_recall = recall_score(final_targets_all, final_predictions_all)
    final_auc = roc_auc_score(final_targets_all, final_predictions_all)

    # return the metrics as a dictionary
    metrics = {'average_loss': average_loss, 
               'task_losses': task_losses,
               'task_balanced_accuracies': task_balanced_accuracies, 
               'task_cross_entropy': task_cross_entropy,
               'task_cluster_cost': task_cluster_cost,
               'task_separation_cost': task_separation_cost,
               'task_l1': task_l1,
               'final_loss': final_loss,
               'final_balanced_accuracy': final_balanced_accuracy,
               'final_f1': final_f1,
               # 'final_precision': final_precision,
               'final_recall': final_recall,
               'final_auc': final_auc
            }
    
    if is_train:
        task_weights = _adjust_weights(task_losses, exponent=5, target_sum=4)
        return metrics, task_weights
    else:
        return metrics

def train_ppnet(model, data_loader, optimizer, device, use_l1_mask=True, coefs=None, task_weights=None):
    train_metrics, task_weights = _train_or_test(model, data_loader, optimizer, device, is_train=True, use_l1_mask=use_l1_mask, coefs=coefs, task_weights=task_weights)
    print("\nFinal Train Metrics:")
    print(f"Total Loss: {train_metrics['average_loss']:.5f}")
    for i, (bal_acc, task_loss, task_ce, task_cc, task_sc) in enumerate(zip(train_metrics['task_balanced_accuracies'], train_metrics['task_losses'], train_metrics['task_cross_entropy'], train_metrics['task_cluster_cost'], train_metrics['task_separation_cost']), 1):
        print(f"Characteristic {i}      - Task Loss: {task_loss:.2f}, Cross Entropy: {task_ce:.2f}, Cluster Cost: {task_cc:.2f}, Separation Cost: {task_sc:.2f}, Balanced Accuracy: {bal_acc*100:.2f}%")
    # Print the metrics for the final output
    print(f"Malignancy Prediction - Binary Cross Entropy Loss: {train_metrics['final_loss']:.2f}, Balanced Accuracy: {train_metrics['final_balanced_accuracy']*100:.2f}%, F1 Score: {train_metrics['final_f1']*100:.2f}%")
    return train_metrics, task_weights

def test_ppnet(model, data_loader, device, use_l1_mask=True, coefs=None, task_weights=None):
    test_metrics = _train_or_test(model, data_loader, None, device, is_train=False, use_l1_mask=use_l1_mask, coefs=coefs, task_weights=task_weights)
    print("\nFinal Test Metrics:")
    print(f"Total Loss: {test_metrics['average_loss']:.5f}")
    for i, (bal_acc, task_loss, task_ce, task_cc, task_sc) in enumerate(zip(test_metrics['task_balanced_accuracies'], test_metrics['task_losses'], test_metrics['task_cross_entropy'], test_metrics['task_cluster_cost'], test_metrics['task_separation_cost']), 1):
        print(f"Characteristic {i}      - Task Loss: {task_loss:.2f}, Cross Entropy: {task_ce:.2f}, Cluster Cost: {task_cc:.2f}, Separation Cost: {task_sc:.2f}, Balanced Accuracy: {bal_acc*100:.2f}%")
    # Print the metrics for the final output
    print(f"Malignancy Prediction - Binary Cross Entropy Loss: {test_metrics['final_loss']:.2f}, Balanced Accuracy: {test_metrics['final_balanced_accuracy']*100:.2f}%, F1 Score: {test_metrics['final_f1']*100:.2f}%")
    return test_metrics
            
def last_only(model):
    for p in model.features.parameters():
        p.requires_grad = False
    for p in model.add_on_layers.parameters():
        p.requires_grad = False
    model.prototype_vectors.requires_grad = False
    for p in model.task_specific_classifier.parameters():
        p.requires_grad = True
    for p in model.final_classifier.parameters():
        p.requires_grad = True # was true

def warm_only(model):
    if model.features.encoder is not None:
        for p in model.features.encoder.parameters():
            p.requires_grad = False
        for p in model.features.adaptation_layers.parameters():
            p.requires_grad = True
        for p in model.features.fpn.parameters():
            p.requires_grad = True
    else:
        for p in model.features.parameters():
            p.requires_grad = False
    for p in model.add_on_layers.parameters():
        p.requires_grad = True
    model.prototype_vectors.requires_grad = True
    for p in model.task_specific_classifier.parameters():
        p.requires_grad = False
    for p in model.final_classifier.parameters():
        p.requires_grad = False
        
def joint(model):
    for p in model.features.parameters():
        p.requires_grad = True
    for p in model.add_on_layers.parameters():
        p.requires_grad = True
    model.prototype_vectors.requires_grad = True
    for p in model.task_specific_classifier.parameters():
        p.requires_grad = False
    for p in model.final_classifier.parameters():
        p.requires_grad = False
    

In [5]:
# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
coefs = {
    'crs_ent': 1,#0.4,#1.1,#0.8,#1.1,#1,#0.5,#changed from 1 at 48
    'clst': 0.8*1.5,#0.2,#0.3,#1.1,#0.8,#5,#0.8,
    'sep': -0.0004,#-0.0004,#-0.17,#-0.22,#-0.5, #used to be -0.08 #dm made it smaller to avoid the problem I noticed where as the separation loss is subtracted, having it too large makes your loss negative making everything explode (also -0.025 works but unstable)
    'l1': 1e-4,
    'l1_occ': 1e-4,
    'trans' : 1e-4
}
epochs = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

task_weights = [1.0] * 4
for epoch in range(epochs):
    print("\n" + "-"*100 + f"\nEpoch: {epoch + 1}/{epochs},\t" + f"Task Weights: {[f'{weight:.2f}' for weight in task_weights]}\n" + "-"*100)
    train_metrics,task_weights = train_ppnet(model, train_dataloader, optimizer, device, coefs=coefs, task_weights=task_weights)
    test_metrics = test_ppnet(model, test_dataloader, device, coefs=coefs)

cuda

----------------------------------------------------------------------------------------------------
Epoch: 1/1,	Task Weights: ['1.00', '1.00', '1.00', '1.00']
----------------------------------------------------------------------------------------------------


                                                   

KeyboardInterrupt: 

In [20]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import time
from tqdm import tqdm

from src.utils.helpers import makedir, save_pickle

def push_prototypes(
    dataloader,  # pytorch dataloader
    model,  # pytorch network with feature encoder and prototype vectors
    device,
    class_specific=True,  # enable pushing protos from only the alotted class
    abstain_class=False,  # indicates K+1-th class is of the "abstain" type
    preprocess_input_function=None,  # normalize if needed
    root_dir_for_saving_prototypes=None,  # if not None, prototypes will be saved in this dir
    epoch_number=None,  # if not provided, prototypes saved previously will be overwritten
    prototype_img_filename_prefix=None,
    prototype_self_act_filename_prefix=None,
    proto_bound_boxes_filename_prefix=None,
    replace_prototypes=True
):
    """
    Search the training set for image patches that are semantically closest to
    each learned prototype, then updates the prototypes to those image patches.

    To do this, it computes the image patch embeddings (IPBs) and saves those
    closest to the prototypes. It also saves the prototype-to-IPB distances and
    predicted occurrence maps.

    If abstain_class==True, it assumes num_classes actually equals to K+1, where
    K is the number of real classes and 1 is the extra "abstain" class for
    uncertainty estimation.
    """

    model.eval()
    print(f"############## push at epoch {epoch_number} #################")

    # creating the folder (with epoch number) to save the prototypes' info and visualizations
    if root_dir_for_saving_prototypes != None:
        if epoch_number != None:
            proto_epoch_dir = os.path.join(root_dir_for_saving_prototypes, "epoch-" + str(epoch_number))
            makedir(proto_epoch_dir)
        else:
            proto_epoch_dir = root_dir_for_saving_prototypes
    else:
        proto_epoch_dir = None

    # find the number of prototypes, and number of classes for this push
    # prototype_shape = (model.prototypes_per_characteristic, model.prototype_shape[1], model.prototype_shape[2], model.prototype_shape[3])  # shape (P, D, 1, 1)
    P = model.prototypes_per_characteristic
    num_characteristics = model.num_characteristics
    num_classes = model.num_classes
    
    proto_class_specific = np.full(P, class_specific)
    
    if abstain_class:
        K = num_classes - 1
        assert K >= 2, "Abstention-push must have >= 2 classes not including abstain"
        # for the uncertainty prototypes, class_specific is False
        # for now assume that each class (inc. unc.) has P_per_class == P/num_classes
        P_per_class = P // num_classes
        proto_class_specific[K * P_per_class : P] = False
    else:
        K = num_classes

    # keep track of the input embedding closest to each prototype
    proto_dist_ = [np.full(P, np.inf) for _ in range(num_characteristics)]  # saves the distances to prototypes (distance = 1-CosineSimilarities). shape (P)
    # save some information dynamically for each prototype
    # which are updated whenever a closer match to prototype is found
    occurrence_map_ = [[None for _ in range(P)] for _ in range(num_characteristics)] # saves the computed occurence maps. shape (P, 1, H, W)
    # saves the input to prototypical layer (conv feature * occurrence map), shape (P, D)
    protoL_input_ = [[None for _ in range(P)] for _ in range(num_characteristics)]
    # saves the input images with embeddings closest to each prototype. shape (P, 3, Ho, Wo)
    image_ = [[None for _ in range(P)] for _ in range(num_characteristics)]
    # saves the gt label. shape (P)
    gt_ = [[None for _ in range(P)] for _ in range(num_characteristics)]
    # saves the prediction logits of cases seen. shape (P, K)
    pred_ = [[None for _ in range(P)] for _ in range(num_characteristics)]
    # saves the filenames of cases closest to each prototype. shape (P)
    # filename_ = [None for _ in range(P)] * num_characteristics # TODO: add filename in getitem of dataloader

    # data_iter = iter(dataloader)
    # iterator = tqdm(range(len(dataloader)), dynamic_ncols=True)
    for X, y, _, _, _ in tqdm(dataloader, leave=False):
        # data_sample = next(data_iter)
        # x = data_sample["cine"]  
        
        if preprocess_input_function is not None:
            X = preprocess_input_function(X)

        # get the network outputs for this instance
        with torch.no_grad():
            x = X.to(device)    # shape (B, 3, Ho, Wo)
            (
                protoL_input_torch,
                proto_dist_torch,
                occurrence_map_torch,
                pred_torch,
            ) = model.push_forward(x)
            # pred_torch = logits.softmax(dim=1)

        # gt = y.detach().cpu().numpy()  # shape (B)
        image = x.detach().cpu().numpy()  # shape (B, 3, Ho, Wo)
        # filename = data_sample["filename"]  # shape (B) 
        
        for characteristic_idx in range(num_characteristics):
            proto_class_identity = np.argmax(model.prototype_class_identity[characteristic_idx].cpu().numpy(), axis=1)  # shape (P)
            # record down batch data as numpy arrays
            gt = y[characteristic_idx].detach().cpu().numpy()
            protoL_input = protoL_input_torch[characteristic_idx].detach().cpu().numpy()
            proto_dist = proto_dist_torch[characteristic_idx].detach().cpu().numpy()
            occurrence_map = occurrence_map_torch[characteristic_idx].detach().cpu().numpy()
            pred = pred_torch[characteristic_idx].detach().cpu().numpy()

            # for each prototype, find the minimum distance and their indices
            for prototype_idx in range(P):
                proto_dist_j = proto_dist[:, prototype_idx]  # (B)
                if proto_class_specific[prototype_idx]:
                    # compare with only the images of the prototype's class
                    proto_dist_j = np.ma.masked_array(proto_dist_j, gt != proto_class_identity[prototype_idx])
                    if proto_dist_j.mask.all():
                        # if none of the classes this batch are the class of interest, move on
                        continue
                proto_dist_j_min = np.amin(proto_dist_j)  # scalar

                # if the distance this batch is smaller than prev.best, save it
                if proto_dist_j_min <= proto_dist_[characteristic_idx][prototype_idx]:
                    a = np.argmin(proto_dist_j)
                    
                    proto_dist_[characteristic_idx][prototype_idx] = proto_dist_j_min
                    protoL_input_[characteristic_idx][prototype_idx] = protoL_input[a, prototype_idx]
                    occurrence_map_[characteristic_idx][prototype_idx] = occurrence_map[a, prototype_idx]
                    pred_[characteristic_idx][prototype_idx] = pred[a]
                    image_[characteristic_idx][prototype_idx] = image[a]
                    gt_[characteristic_idx][prototype_idx] = gt[a]
                    # filename_[characteristic_idx][prototype_idx] = filename[a]

    prototypes_similarity_to_src_ROIs = 1 - np.array(proto_dist_)  # invert distance to similarity  shape (P)
    prototypes_occurrence_maps = np.array(occurrence_map_)  # shape (P, 1, H, W)
    prototypes_src_imgs = np.array(image_)  # shape (P, 3, Ho, Wo)
    prototypes_gts = np.array(gt_)  # shape (P)
    prototypes_preds = np.array(pred_)  # shape (P, K)
    # prototypes_filenames = np.array(filename_)  # shape (P)

    # save the prototype information in a pickle file
    prototype_data_dict = {
        # "prototypes_filenames": prototypes_filenames,
        "prototypes_src_imgs": prototypes_src_imgs,
        "prototypes_gts": prototypes_gts,
        "prototypes_preds": prototypes_preds,
        "prototypes_occurrence_maps": prototypes_occurrence_maps,
        "prototypes_similarity_to_src_ROIs": prototypes_similarity_to_src_ROIs,
    }
    save_pickle(prototype_data_dict, f"{proto_epoch_dir}/prototypes_info.pickle")

    if replace_prototypes:
        protoL_input_ = np.array(protoL_input_)
        print("\tExecuting push ...")
        
        for idx, (prototype_vectors, protoL_input_char) in enumerate(zip(model.prototype_vectors, protoL_input_)):
            prototype_update = np.reshape(protoL_input_char, prototype_vectors.shape)
            with torch.no_grad():
                prototype_vectors.data.copy_(torch.tensor(prototype_update, dtype=torch.float32).to(device))
            
        # prototype_update = np.reshape(protoL_input_, tuple(prototype_shape))
        # model.prototype_vectors.data.copy_(torch.tensor(prototype_update, dtype=torch.float32).cuda())


In [15]:
model.num_characteristics

4

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

push_prototypes(train_dataloader, 
                model, 
                device=device,
                class_specific=True, 
                preprocess_input_function=None, 
                root_dir_for_saving_prototypes=None,
                epoch_number=1,
                replace_prototypes=True
            )

############## push at epoch 1 #################


  0%|          | 0/1192 [00:00<?, ?it/s]

                                                   

FileNotFoundError: [Errno 2] No such file or directory: 'None/prototypes_info.pickle'

In [8]:
from src.models.backbone_models import denseNet121, denseNet169, resNet152, resNet34, vgg16, vgg19, denseFPN_121
import torch 
import torch.nn as nn

from efficientnet_pytorch_3d import EfficientNet3D
model = EfficientNet3D.from_name("efficientnet-b3", in_channels=1, override_params={'num_classes': 1})
# model = EfficientNet3D.from_name("efficientnet-b0", override_params={'num_classes': 1}, in_channels=1).extract_features
# model = nn.Sequential(*list(model.children())[:-1])

# print number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

dummy_input = torch.randn(1, 1, 128, 128, 128)

features = model(dummy_input)

print(features.shape)

Total number of parameters:  12060009
torch.Size([1, 1])


In [5]:
model

<bound method EfficientNet3D.extract_features of EfficientNet3D(
  (_conv_stem): Conv3dStaticSamePadding(
    1, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), bias=False
    (static_padding): ZeroPad2d((0, 1, 0, 1, 0, 1))
  )
  (_bn0): BatchNorm3d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock3D(
      (_depthwise_conv): Conv3dStaticSamePadding(
        32, 32, kernel_size=(3, 3, 3), stride=[2, 2, 2], groups=32, bias=False
        (static_padding): ZeroPad2d((0, 1, 0, 1, 0, 1))
      )
      (_bn1): BatchNorm3d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv3dStaticSamePadding(
        32, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv3dStaticSamePadding(
        8, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Co

In [1]:
from src.models.backbone_models import denseNet121, denseNet169, resNet152, resNet34, vgg16, vgg19, denseFPN_121, EfficeientNet3d
import torch 

model = EfficientNet3d

# print number of parameters
# total_params = sum(p.numel() for p in model.parameters())
# print("Total number of parameters: ", total_params)

dummy_input = torch.randn(1, 1, 128, 128, 128)

features = model(dummy_input)

print(features)

  warn(


ImportError: cannot import name 'EfficientNet3d' from 'src.models.backbone_models' (c:\Users\jerem\Desktop\ICL-MSc\Final_Year_Project\FYP-interpretable-deep-learning\2D-model\src\models\backbone_models.py)

In [10]:
model

EfficientNet3D(
  (_conv_stem): Conv3dStaticSamePadding(
    1, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), bias=False
    (static_padding): ZeroPad2d((0, 1, 0, 1, 0, 1))
  )
  (_bn0): BatchNorm3d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock3D(
      (_depthwise_conv): Conv3dStaticSamePadding(
        32, 32, kernel_size=(3, 3, 3), stride=[2, 2, 2], groups=32, bias=False
        (static_padding): ZeroPad2d((0, 1, 0, 1, 0, 1))
      )
      (_bn1): BatchNorm3d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv3dStaticSamePadding(
        32, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv3dStaticSamePadding(
        8, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv3dStaticSamePadding(
        32, 16, kernel_siz

In [11]:
nn.Sequential(*list(model.children())[:-4])

Sequential(
  (0): Conv3dStaticSamePadding(
    1, 40, kernel_size=(3, 3, 3), stride=(2, 2, 2), bias=False
    (static_padding): ZeroPad2d((0, 1, 0, 1, 0, 1))
  )
  (1): BatchNorm3d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (2): ModuleList(
    (0): MBConvBlock3D(
      (_depthwise_conv): Conv3dStaticSamePadding(
        40, 40, kernel_size=(3, 3, 3), stride=[2, 2, 2], groups=40, bias=False
        (static_padding): ZeroPad2d((0, 1, 0, 1, 0, 1))
      )
      (_bn1): BatchNorm3d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv3dStaticSamePadding(
        40, 10, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv3dStaticSamePadding(
        10, 40, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv3dStaticSamePadding(
        40, 24, kernel_size=(1, 1, 1), stride=

<bound method EfficientNet3D.extract_features of EfficientNet3D(
  (_conv_stem): Conv3dStaticSamePadding(
    1, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), bias=False
    (static_padding): ZeroPad2d((0, 1, 0, 1, 0, 1))
  )
  (_bn0): BatchNorm3d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock3D(
      (_depthwise_conv): Conv3dStaticSamePadding(
        32, 32, kernel_size=(3, 3, 3), stride=[2, 2, 2], groups=32, bias=False
        (static_padding): ZeroPad2d((0, 1, 0, 1, 0, 1))
      )
      (_bn1): BatchNorm3d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv3dStaticSamePadding(
        32, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv3dStaticSamePadding(
        8, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Co

In [44]:
import torch



from torchsummary import summary
summary(model, input_size=(1, 200, 200, 200))

NameError: name 'device' is not defined

In [31]:
nn.Sequential(*list(model.children())[:])

Sequential(
  (0): ResNetFeatureMapsExtractor(
    (m): SequentialMultiOutput(
      (0): Sequential(
        (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      )
      (1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1

In [6]:
# model.features
import torch
import torch.nn as nn

nn.Sequential(*list(model.features)[:])

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con