In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

class densetFPN_121(nn.Module):
    """ DenseNet121-based Feature Pyramid Network (FPN) for feature extraction. 
        Total number of parameters:  8232320 (8.23 million) """ 
    def __init__(self, weights='DEFAULT', common_channel_size=256, output_channel_size=256):
        super(densetFPN_121, self).__init__()
        original_densenet = models.densenet121(weights=weights)
        
        # Initial layers: extract features without modification
        self.encoder = nn.ModuleList([
            nn.Sequential(*list(original_densenet.features.children())[:6], nn.Dropout(0.4)),   # 128x12x12
            nn.Sequential(*list(original_densenet.features.children())[6:8], nn.Dropout(0.4)),  # 256x6x6
            nn.Sequential(*list(original_densenet.features.children())[8:10], nn.Dropout(0.4)), # 896x3x3
            nn.Sequential(*list(original_densenet.features.children())[10:-1], nn.Dropout(0.4)) # 1920x3x3
        ])
        
        # Define convolutional layers for adapting channel sizes
        fpn_channels = [128, 256, 512, 1024]
        self.adaptation_layers = nn.ModuleDict({
            f'adapt{i+1}': nn.Conv2d(fpn_channels[i], common_channel_size, kernel_size=1)
            for i in range(4)
        })

        # Define FPN layers
        self.fpn = nn.ModuleDict({
            f'fpn{i+1}': nn.Conv2d(common_channel_size, common_channel_size, kernel_size=1)
            for i in range(3)
        })

        self.merge_layers = nn.Sequential(
            nn.Conv2d(common_channel_size, output_channel_size, kernel_size=3), # kernel size 1 or 3
            nn.BatchNorm2d(output_channel_size),
            nn.ReLU(),
            nn.Dropout(0.4) # 0.2
        )

    def forward(self, x):
        # Encoder
        features = []
        for encoder in self.encoder:
            x = encoder(x)
            features.append(x)
        
        # Merge channels using 1x1 convolutions
        adapted_features = [self.adaptation_layers[f'adapt{i+1}'](features[i]) for i in range(4)]
        
        # FPN integration using top-down pathway
        fpn_output = adapted_features.pop()  # Start with the deepest features
        for i in reversed(range(3)):
            upsampled = F.interpolate(fpn_output, size=adapted_features[i].shape[-2:], mode='nearest')
            fpn_output = self.fpn[f'fpn{i+1}'](upsampled + adapted_features[i])
        
        # Merge features
        merged_features = self.merge_layers(fpn_output)
        
        return merged_features


model = densetFPN_121(weights=None)

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(50, 3, 100, 100)

# Forward pass through the model with dummy input
features = model(dummy_input)

# Print output shapes to verify
print("Features shape:", features.shape)

# print(features.conv_info())

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class efficientDecoder_v2_s(nn.Module):
    def __init__(self, num_channels=[24, 48, 64, 128, 160, 256], output_channel_size=256, output_feature_size=25):
        super(efficientDecoder_v2_s, self).__init__()
        # Load EfficientNet V2 Small features
        efficientnet_v2_s = models.efficientnet_v2_s(weights='DEFAULT').features[:-1]

        # Modularize encoders
        self.encoders = nn.ModuleList([
            nn.Sequential(*list(efficientnet_v2_s.children())[:2], nn.Dropout(0.1)),    # 24x50x50
            nn.Sequential(*list(efficientnet_v2_s.children())[2:3], nn.Dropout(0.1)),   # 48x25x25
            nn.Sequential(*list(efficientnet_v2_s.children())[3:4], nn.Dropout(0.2)),   # 64x13x13
            nn.Sequential(*list(efficientnet_v2_s.children())[4:5], nn.Dropout(0.2)),   # 128x7x7
            nn.Sequential(*list(efficientnet_v2_s.children())[5:6], nn.Dropout(0.3)),   # 160x7x7 # TODO: Check whether to skip 128x7x7
            nn.Sequential(*list(efficientnet_v2_s.children())[6:7], nn.Dropout(0.3))    # 256x4x4
        ])
        
        # Modularize upconvolutions
        self.upconvs = nn.ModuleList([
            nn.ConvTranspose2d(in_channels=256, out_channels=160, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.Conv2d(in_channels=160, out_channels=128, kernel_size=1, stride=1),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(in_channels=64, out_channels=48, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(in_channels=48, out_channels=24, kernel_size=2, stride=2)
        ])
        
        # Modularize decoders
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(num_channels[i] * 2, num_channels[i], kernel_size=3, padding=1),
                nn.BatchNorm2d(num_channels[i]),
                nn.SiLU(),
                nn.Dropout(0.1 + i * 0.05)
            ) for i in range(len(num_channels)-1)
        ])

        # Optional, merge layers to increase the number of channels
        self.merge_layers = nn.Sequential(
            nn.Conv2d(24, output_channel_size, kernel_size=1),
            nn.BatchNorm2d(output_channel_size),
            nn.SiLU(),
            nn.Dropout(0.3)
        )
        
        self.global_avg_pool = nn.AdaptiveAvgPool2d(output_feature_size) # to reduce noise and overfitting

    def forward(self, x):
        # Encoder
        features = []
        for encoder in self.encoders:
            x = encoder(x)
            features.append(x)
        
        # Decoder
        x = features.pop()
        for upconv, decoder, feature in zip(self.upconvs, reversed(self.decoders), reversed(features)):
            x = upconv(x)
            x = torch.cat((x, feature), dim=1)
            x = decoder(x)
        
        x = self.merge_layers(x) # Introduced to increase the number of channels
        pooled_features = self.global_avg_pool(x) # Introduced to reduce noise and overfitting
        
        return pooled_features

model = efficientDecoder_v2_s()

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(50, 3, 100, 100)

# Forward pass through the model with dummy input
features = model(dummy_input)

# Print output shapes to verify
print("Features shape:", features.shape)

In [None]:
for i in reversed(range(1,6)):
    print(i)

In [None]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

class efficientDecoder_v2_s(nn.Module):
    def __init__(self, output_channel_size=256, output_feature_size=25):
        super(efficientDecoder_v2_s, self).__init__()
        # Load EfficientNet V2 Small features
        efficientnet_v2_s = models.efficientnet_v2_s(weights='DEFAULT').features[:-1]

        # Modularize encoders
        self.encoders = nn.ModuleList([
            nn.Sequential(*list(efficientnet_v2_s.children())[:2], nn.Dropout(0.1)),    # 24x50x50
            nn.Sequential(*list(efficientnet_v2_s.children())[2:3], nn.Dropout(0.1)),   # 48x25x25
            nn.Sequential(*list(efficientnet_v2_s.children())[3:4], nn.Dropout(0.2)),   # 64x13x13
            nn.Sequential(*list(efficientnet_v2_s.children())[4:5], nn.Dropout(0.2)),   # 128x7x7
            nn.Sequential(*list(efficientnet_v2_s.children())[5:6], nn.Dropout(0.3)),   # 160x7x7
            nn.Sequential(*list(efficientnet_v2_s.children())[6:7], nn.Dropout(0.3))    # 256x4x4
        ])
        
        # Modularize upconvolutions
        self.upconvs = nn.ModuleList([
            nn.ConvTranspose2d(in_channels=256, out_channels=160, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.Conv2d(in_channels=160, out_channels=128, kernel_size=1, stride=1),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(in_channels=64, out_channels=48, kernel_size=2, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(in_channels=48, out_channels=24, kernel_size=2, stride=2)
        ])
        
        # Modularize decoders
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(160*2, 160, kernel_size=3, padding=1),
                nn.BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.3, inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(128*2, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.3, inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(64*2, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.2, inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(48*2, 48, kernel_size=3, padding=1),
                nn.BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.2, inplace=True)
            ),
            nn.Sequential(
                nn.Conv2d(24*2, 24, kernel_size=3, padding=1),
                nn.BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(inplace=True),
                nn.Dropout(0.1, inplace=True)
            )
        ])

        # Optional, merge layers to increase the number of channels
        self.merge_layers = nn.Sequential(
            nn.Conv2d(24, output_channel_size, kernel_size=1),
            nn.BatchNorm2d(output_channel_size),
            nn.SiLU(),
            nn.Dropout(0.3)
        )
        
        self.global_avg_pool = nn.AdaptiveAvgPool2d(output_feature_size) # to reduce noise and overfitting

    def forward(self, x):
        # Encoder
        features = []
        for encoder in self.encoders:
            x = encoder(x)
            features.append(x)
        
        # Decoder
        x = features.pop()
        for upconv, decoder, feature in zip(self.upconvs, self.decoders, reversed(features)):
            x = upconv(x)
            x = torch.cat((x, feature), dim=1)
            x = decoder(x)
        
        x = self.merge_layers(x) # Introduced to increase the number of channels
        pooled_features = self.global_avg_pool(x) # Introduced to reduce noise and overfitting
        
        return pooled_features

model = efficientDecoder_v2_s()

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Example initialization and forward pass
# model = EfficientDecoder_v2_s()
dummy_input = torch.randn(10, 3, 100, 100)  # Adjust size according to your actual input
output = model(dummy_input)
print(f"Output shape: {output.shape}")

In [1]:
input_dim=(128,12,12)

input_channels, input_height, input_width = input_dim

print(input_channels, input_height, input_width)

128 12 12


In [9]:
import torch
from src.models.baseline_models import construct_baselineModel, construct_baseModel, BaseModel
from src.models.backbone_models import densetFPN_121, densetFPN_201

# model = densetFPN_121()
# Create the model instance
# model = BaseModel(backbone=densetFPN_121, weights='DEFAULT', input_dim=(256,12,12))
model = construct_baseModel(backbone_name='densetFPN_121', weights='DEFAULT', input_dim=(256,12,12))

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(50, 3, 100, 100)

# Forward pass through the model with dummy input
features = model(dummy_input)

# Print output shapes to verify
print("Features shape:", features.shape)

Total number of parameters:  45460865
torch.Size([50, 3, 100, 100])
torch.Size([50, 256, 12, 12])
Features shape: torch.Size([50, 1])


In [2]:
features

(tensor([[0.5106],
         [0.5035],
         [0.5488],
         [0.7421],
         [0.5720],
         [0.3836],
         [0.5911],
         [0.4598],
         [0.5526],
         [0.5835],
         [0.4567],
         [0.5489],
         [0.5642],
         [0.4666],
         [0.4395],
         [0.6266],
         [0.4210],
         [0.5086],
         [0.5422],
         [0.5515],
         [0.5906],
         [0.6073],
         [0.4439],
         [0.5127],
         [0.3540],
         [0.6443],
         [0.3925],
         [0.5958],
         [0.4544],
         [0.4678],
         [0.4648],
         [0.5237],
         [0.6828],
         [0.5871],
         [0.5303],
         [0.5280],
         [0.5333],
         [0.6036],
         [0.6479],
         [0.5992],
         [0.5446],
         [0.5203],
         [0.4284],
         [0.5404],
         [0.3937],
         [0.5419],
         [0.5040],
         [0.5163],
         [0.6542],
         [0.6807]], grad_fn=<SigmoidBackward0>),
 [tensor([[0.3055],


In [None]:
# Define the dictionary outside the class
MODEL_DICT = {
    'densetFPN_121': densetFPN_121,
    'densetFPN_201': densetFPN_201,
    'efficientFPN_v2_s': efficientFPN_v2_s,
    'efficientDecoder_v2_s': efficientDecoder_v2_s
}

class BaselineModel(nn.Module):
    def __init__(self, backbone, num_tasks=5, feature_dim=(256, 25, 25)):
        super(BaselineModel, self).__init__()
        self.backbone = backbone()  # Instantiate the backbone passed as a class
        
        feature_channels, feature_width, feature_height = feature_dim
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d((feature_width, feature_height))
        
        self.task_specific_layers = nn.ModuleList([
            nn.Sequential(
                nn.Flatten(),
                nn.Linear(feature_channels * feature_width * feature_height, 1024),
                nn.BatchNorm1d(1024),
                nn.ReLU(),
                nn.Dropout(0.2)
            ) for _ in range(num_tasks)
        ])
        
        self.task_specific_classifier = nn.ModuleList([
            nn.Linear(1024, 1) for _ in range(num_tasks)
        ])
        
        self.final_classifier = nn.Sequential(
            nn.Linear(1024 * num_tasks, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(1024, 1)
        )
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.adaptive_pool(x)
        
        intermediate_outputs = [layer(x) for layer in self.task_specific_layers]
        concatenated_outputs = torch.cat(intermediate_outputs, dim=1)
        
        task_outputs = [torch.sigmoid(classifier(io)) for io, classifier in zip(intermediate_outputs, self.task_specific_classifier)]
        
        final_output = torch.sigmoid(self.final_classifier(concatenated_outputs))
        
        return final_output, task_outputs

def construct_baselineModel(model_name, num_tasks=5, feature_dim=(256, 25, 25)):
    if model_name not in MODEL_DICT:
        raise ValueError(f"Unsupported model name {model_name}")
    backbone = MODEL_DICT[model_name]
    return BaselineModel(backbone, num_tasks, feature_dim)

In [5]:
import torch
from torchvision import models
from src.models.backbone_models import denseFPN_121, denseFPN_201

# Create the model instance
model = models.efficientnet_v2_s(weights=None).features[:-1]

print(model)

denseFPN_121(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_runni

In [None]:
import torch
from src.models.backbone_models import denseFPN_121, denseFPN_201

# Create the model instance
model = denseFPN_121()

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(50, 3, 100, 100)

# Forward pass through the model with dummy input
features = model(dummy_input)

# Print output shapes to verify
print("Features shape:", features.shape)

print("Output channels:", model.get_output_channels())

  warn(


Total number of parameters:  7643776
Features shape: torch.Size([50, 256, 12, 12])


KeyError: -1

In [17]:
import time
import torch

from helpers import list_of_distances, make_one_hot

def _train_or_test(model, dataloader, optimizer=None, class_specific=True, use_l1_mask=True,
                   coefs=None, log=print):
    '''
    model: the multi-gpu model
    dataloader:
    optimizer: if None, will be test evaluation
    '''
    is_train = optimizer is not None
    start = time.time()
    n_examples = 0
    n_correct = 0
    n_batches = 0
    total_cross_entropy = 0
    total_cluster_cost = 0
    # separation cost is meaningful only for class_specific
    total_separation_cost = 0
    total_avg_separation_cost = 0

    for i, (image, label) in enumerate(dataloader):
        input = image.cuda()
        target = label.cuda()

        # torch.enable_grad() has no effect outside of no_grad()
        grad_req = torch.enable_grad() if is_train else torch.no_grad()
        with grad_req:
            # nn.Module has implemented __call__() function
            # so no need to call .forward
            output, min_distances = model(input)

            # compute loss
            cross_entropy = torch.nn.functional.cross_entropy(output, target)

            if class_specific:
                max_dist = (model.module.prototype_shape[1]
                            * model.module.prototype_shape[2]
                            * model.module.prototype_shape[3])

                # prototypes_of_correct_class is a tensor of shape batch_size * num_prototypes
                # calculate cluster cost
                prototypes_of_correct_class = torch.t(model.module.prototype_class_identity[:,label]).cuda()
                inverted_distances, _ = torch.max((max_dist - min_distances) * prototypes_of_correct_class, dim=1)
                cluster_cost = torch.mean(max_dist - inverted_distances)

                # calculate separation cost
                prototypes_of_wrong_class = 1 - prototypes_of_correct_class
                inverted_distances_to_nontarget_prototypes, _ = \
                    torch.max((max_dist - min_distances) * prototypes_of_wrong_class, dim=1)
                separation_cost = torch.mean(max_dist - inverted_distances_to_nontarget_prototypes)

                # calculate avg cluster cost
                avg_separation_cost = \
                    torch.sum(min_distances * prototypes_of_wrong_class, dim=1) / torch.sum(prototypes_of_wrong_class, dim=1)
                avg_separation_cost = torch.mean(avg_separation_cost)
                
                if use_l1_mask:
                    l1_mask = 1 - torch.t(model.module.prototype_class_identity).cuda()
                    l1 = (model.module.last_layer.weight * l1_mask).norm(p=1)
                else:
                    l1 = model.module.last_layer.weight.norm(p=1) 

            else:
                min_distance, _ = torch.min(min_distances, dim=1)
                cluster_cost = torch.mean(min_distance)
                l1 = model.module.last_layer.weight.norm(p=1)

            # evaluation statistics
            _, predicted = torch.max(output.data, 1)
            n_examples += target.size(0)
            n_correct += (predicted == target).sum().item()

            n_batches += 1
            total_cross_entropy += cross_entropy.item()
            total_cluster_cost += cluster_cost.item()
            total_separation_cost += separation_cost.item()
            total_avg_separation_cost += avg_separation_cost.item()

        # compute gradient and do SGD step
        if is_train:
            if class_specific:
                if coefs is not None:
                    loss = (coefs['crs_ent'] * cross_entropy
                          + coefs['clst'] * cluster_cost
                          + coefs['sep'] * separation_cost
                          + coefs['l1'] * l1)
                else:
                    loss = cross_entropy + 0.8 * cluster_cost - 0.08 * separation_cost + 1e-4 * l1
            else:
                if coefs is not None:
                    loss = (coefs['crs_ent'] * cross_entropy
                          + coefs['clst'] * cluster_cost
                          + coefs['l1'] * l1)
                else:
                    loss = cross_entropy + 0.8 * cluster_cost + 1e-4 * l1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        del input
        del target
        del output
        del predicted
        del min_distances

    end = time.time()

    log('\ttime: \t{0}'.format(end -  start))
    log('\tcross ent: \t{0}'.format(total_cross_entropy / n_batches))
    log('\tcluster: \t{0}'.format(total_cluster_cost / n_batches))
    if class_specific:
        log('\tseparation:\t{0}'.format(total_separation_cost / n_batches))
        log('\tavg separation:\t{0}'.format(total_avg_separation_cost / n_batches))
    log('\taccu: \t\t{0}%'.format(n_correct / n_examples * 100))
    log('\tl1: \t\t{0}'.format(model.module.last_layer.weight.norm(p=1).item()))
    p = model.module.prototype_vectors.view(model.module.num_prototypes, -1).cpu()
    with torch.no_grad():
        p_avg_pair_dist = torch.mean(list_of_distances(p, p))
    log('\tp dist pair: \t{0}'.format(p_avg_pair_dist.item()))

    return n_correct / n_examples


def train(model, dataloader, optimizer, class_specific=False, coefs=None, log=print):
    assert(optimizer is not None)
    
    log('\ttrain')
    model.train()
    return _train_or_test(model=model, dataloader=dataloader, optimizer=optimizer,
                          class_specific=class_specific, coefs=coefs, log=log)


def test(model, dataloader, class_specific=False, log=print):
    log('\ttest')
    model.eval()
    return _train_or_test(model=model, dataloader=dataloader, optimizer=None,
                          class_specific=class_specific, log=log)


def last_only(model, log=print):
    for p in model.module.features.parameters():
        p.requires_grad = False
    for p in model.module.add_on_layers.parameters():
        p.requires_grad = False
    model.module.prototype_vectors.requires_grad = False
    for p in model.module.last_layer.parameters():
        p.requires_grad = True
    
    log('\tlast layer')


def warm_only(model, log=print):
    for p in model.module.features.parameters():
        p.requires_grad = False
    for p in model.module.add_on_layers.parameters():
        p.requires_grad = True
    model.module.prototype_vectors.requires_grad = True
    for p in model.module.last_layer.parameters():
        p.requires_grad = True
    
    log('\twarm')


def joint(model, log=print):
    for p in model.module.features.parameters():
        p.requires_grad = True
    for p in model.module.add_on_layers.parameters():
        p.requires_grad = True
    model.module.prototype_vectors.requires_grad = True
    for p in model.module.last_layer.parameters():
        p.requires_grad = True
    
    log('\tjoint')

Total number of parameters:  15688838
Backbone output shape: torch.Size([50, 512, 1, 1])
Prototype vectors: 50
Unsqueeze Prototype 0 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 1 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 2 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 3 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 4 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 1, 1])
Processed shape: torch.Size([50, 1024])
Unsqueeze Prototype 5 shape: torch.Size([1, 512, 1, 1]), task index: 0
Distances shape: torch.Size([50, 512, 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (50x51200 and 5120x1024)

In [1]:
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.optim import optimizer
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score, f1_score, precision_score, recall_score, roc_auc_score

def _train_or_test(model, data_loader, optimizer, device, is_train=True, use_l1_mask=True, coefs=None, task_weights=None):
    model.to(device)
    if is_train:
        model.train()
    else:
        model.eval()
    
    total_loss = 0
    
    total_correct = [0] * 5  # For tasks
    total_samples = [0] * 5  # For tasks
    
    task_cross_entropy = [0.0] * 5
    task_cluster_cost = [0.0] * 5
    task_separation_cost = [0.0] * 5
    task_avg_separation_cost = [0.0] * 5
    task_l1 = [0.0] * 5
    final_pred_targets = [[] for _ in range(5)]
    final_pred_outputs = [[] for _ in range(5)]
    
    final_correct = 0  # For final output
    final_samples = 0  # For final output
    final_targets = []  # For calculating balanced accuracy for final output
    final_outputs = []  # For calculating balanced accuracy for final output
    
    n_batches = 0
    context = torch.enable_grad() if is_train else torch.no_grad()
    with context:
        for X, targets, bweights_chars, final_target, bweight in tqdm(data_loader, leave=False):
            X = X.to(device)
            bweights_chars = [b.float().to(device) for b in bweights_chars]
            
            targets2 = [F.one_hot(t.squeeze(), num_classes=2).float().to(device) for t in targets]
            targets = [t.squeeze().to(device) for t in targets]
            final_target = final_target.float().unsqueeze(1).to(device)
            bweight = bweight.float().unsqueeze(1).to(device)
            
            final_output, task_outputs, min_distances = model(X)
            
            batch_loss = 0.0
            for i, (task_output, min_distance, target, bweight_char) in enumerate(zip(task_outputs, min_distances, targets, bweights_chars)):
                # Get the prototype identity for each characteristic
                prototype_char_identity = model.prototype_class_identity[i].to(device)
                
                # Get the max distance between prototypes
                max_dist = (model.prototype_shape[1] * model.prototype_shape[2] * model.prototype_shape[3])
                
                # Compute cross entropy cost for each characteristic
                cross_entropy = torch.nn.functional.cross_entropy(task_output, target, weight=bweight_char[0])

                # Compute cluster cost for each characteristic
                prototypes_of_correct_class = torch.t(prototype_char_identity[:,target]).to(device)    # batch_size * num_prototypes
                inverted_distances, _ = torch.max((max_dist - min_distance) * prototypes_of_correct_class, dim=1)
                cluster_cost = torch.mean(max_dist - inverted_distances) # Increase the distance between the prototypes of the same class

                # Compute separation cost for each characteristic
                prototypes_of_wrong_class = 1 - prototypes_of_correct_class
                inverted_distances_to_nontarget_prototypes, _ = torch.max((max_dist - min_distance) * prototypes_of_wrong_class, dim=1)
                separation_cost = torch.mean(max_dist - inverted_distances_to_nontarget_prototypes) # Decrease the distance between the prototypes of different classes

                # Compute average separation cost for each characteristic
                avg_separation_cost = torch.sum(min_distance * prototypes_of_wrong_class, dim=1) / torch.sum(prototypes_of_wrong_class, dim=1)
                avg_separation_cost = torch.mean(avg_separation_cost)
                
                # Compute l1 regularization for each characteristic
                if use_l1_mask:
                    l1_mask = 1 - torch.t(prototype_char_identity).to(device)
                    l1 = (model.task_specific_classifier[i].weight * l1_mask).norm(p=1)
                else:
                    l1 = model.task_specific_classifier[i].weight.norm(p=1) 
                    
                # Compute accuracy for each characteristic
                preds = task_output.argmax(dim=1)
                total_correct[i] += (preds == target).sum().item()
                total_samples[i] += target.size(0)
                
                # Collect data for balanced accuracy for each characteristic
                final_pred_targets[i].extend(target.cpu().numpy())
                final_pred_outputs[i].extend(preds.detach().cpu().numpy())

                task_cross_entropy[i] += cross_entropy.item()
                task_cluster_cost[i] += cluster_cost.item()
                task_separation_cost[i] += separation_cost.item()
                task_avg_separation_cost[i] += avg_separation_cost.item()
                task_l1[i] += l1
                
                batch_loss += (coefs['crs_ent'] * cross_entropy + 
                               coefs['clst'] * cluster_cost + 
                               coefs['sep'] * separation_cost +
                               coefs['l1'] * l1)

            # Compute binary cross entropy loss for final output
            final_loss = torch.nn.functional.binary_cross_entropy(final_output, final_target, weight=bweight)
            batch_loss += final_loss
            
            # Compute statistics for final accuracy
            final_preds = final_output.round()
            final_correct += (final_preds == final_target).sum().item()
            final_samples += final_target.size(0)
            final_targets.extend(final_target.cpu().numpy())
            final_outputs.extend(final_preds.detach().cpu().numpy())
            
            total_loss += batch_loss.item()  # Sum up total loss
            
            # compute gradient and do SGD step
            if is_train:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
                
            n_batches += 1
    
    # TODO: Add the seperate characteristic losses to the return dictionary, and include final F1 score, final precision, final recall, and final AUC
    average_loss = total_loss / n_batches
    task_accuracies = [correct / samples for correct, samples in zip(total_correct, total_samples)]
    task_balanced_accuracies = [balanced_accuracy_score(targets, outputs) for targets, outputs in zip(final_pred_targets, final_pred_outputs)]
    final_accuracy = final_correct / final_samples
    final_balanced_accuracy = balanced_accuracy_score(final_targets, final_outputs)
    # final_f1 = f1_score(final_targets, final_outputs)
    # final_precision = precision_score(final_targets, final_outputs)
    # final_recall = recall_score(final_targets, final_outputs)
    # final_auc = roc_auc_score(final_targets, final_outputs)
    
    # task_cross_entropy = [t / n_batches for t in task_cross_entropy]
    # task_cluster_cost = [t / n_batches for t in task_cluster_cost]
    # task_separation_cost = [t / n_batches for t in task_separation_cost]
    # task_avg_separation_cost = [t / n_batches for t in task_avg_separation_cost]
    # task_l1 = [t / n_batches for t in task_l1]
    
    # return the metrics as a dictionary
    metrics = {'average_loss': average_loss, 
               'task_accuracies': task_accuracies, 
               'task_balanced_accuracies': task_balanced_accuracies, 
               'final_accuracy': final_accuracy, 
               'final_balanced_accuracy': final_balanced_accuracy}
               # 'task_cross_entropy': task_cross_entropy,
               # 'task_cluster_cost': task_cluster_cost,
               # 'task_separation_cost': task_separation_cost,
               # 'task_avg_separation_cost': task_avg_separation_cost,
               # 'task_l1': task_l1}
    
    if is_train:
        return metrics
    else:
        return metrics

def train_ppnet(model, data_loader, optimizer, device, use_l1_mask=True, coefs=None, task_weights=None):
    train_metrics = _train_or_test(model, data_loader, optimizer, device, is_train=True, use_l1_mask=use_l1_mask, coefs=coefs, task_weights=task_weights)
    print(f"Train loss: {train_metrics['average_loss']:.5f}")
    for i, (acc, bal_acc) in enumerate(zip(train_metrics['task_accuracies'], train_metrics['task_balanced_accuracies']), 1):
        print(f"Task {i} - Train Accuracy: {acc*100:.2f}%, Train Balanced Accuracy: {bal_acc*100:.2f}%")
    # Print the metrics for the final output
    print(f"Final Output - Train Accuracy: {train_metrics['final_accuracy']*100:.2f}%, Train Balanced Accuracy: {train_metrics['final_balanced_accuracy']*100:.2f}%")
    return train_metrics

def test_ppnet(model, data_loader, device, use_l1_mask=True, coefs=None, task_weights=None):
    test_metrics = _train_or_test(model, data_loader, None, device, is_train=False, use_l1_mask=use_l1_mask, coefs=coefs, task_weights=task_weights)
    print(f"Test loss: {test_metrics['average_loss']:.5f}")
    for i, (acc, bal_acc) in enumerate(zip(test_metrics['task_accuracies'], test_metrics['task_balanced_accuracies']), 1):
        print(f"Task {i} - Test Accuracy: {acc*100:.2f}%, Test Balanced Accuracy: {bal_acc*100:.2f}%")
    # Print the metrics for the final output
    print(f"Final Output - Test Accuracy: {test_metrics['final_accuracy']*100:.2f}%, Test Balanced Accuracy: {test_metrics['final_balanced_accuracy']*100:.2f}%")
    return test_metrics
            
def last_only(model):
    for p in model.features.parameters():
        p.requires_grad = False
    for p in model.add_on_layers.parameters():
        p.requires_grad = False
    model.prototype_vectors.requires_grad = False
    for p in model.task_specific_classifier.parameters():
        p.requires_grad = True
    for p in model.final_classifier.parameters():
        p.requires_grad = True

def warm_only(model):
    for p in model.features.parameters():
        p.requires_grad = False
    for p in model.add_on_layers.parameters():
        p.requires_grad = True
    model.prototype_vectors.requires_grad = False
    for p in model.task_specific_classifier.parameters():
        p.requires_grad = True
    for p in model.final_classifier.parameters():
        p.requires_grad = True
        
def joint(model):
    for p in model.features.parameters():
        p.requires_grad = True
    for p in model.add_on_layers.parameters():
        p.requires_grad = True
    model.prototype_vectors.requires_grad = False
    for p in model.task_specific_classifier.parameters():
        p.requires_grad = True
    for p in model.final_classifier.parameters():
        p.requires_grad = True
    

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

from src.utils.receptive_field import compute_proto_layer_rf_info_v2
from src.models.backbone_models import denseFPN_121, denseFPN_201, efficientFPN_v2_s, efficientDecoder_v2_s, denseNet121, denseNet201

# Dictionary of supported backbone models
BACKBONE_DICT = {
    'denseNet121': denseNet121,
    'denseNet201': denseNet201,
    'efficientNetV2_s': models.efficientnet_v2_s(weights='DEFAULT').features[:-1],
    'denseFPN_121': denseFPN_121,
    'denseFPN_201': denseFPN_201,
    'efficientFPN_v2_s': efficientFPN_v2_s,
    'efficientDecoder_v2_s': efficientDecoder_v2_s
}

class PPNet(nn.Module):
    def __init__(self, features, img_size, prototype_shape, num_characteristics, proto_layer_rf_info=None, init_weights=True, prototype_activation_function='log', add_on_layers_type='bottleneck'):
        super(PPNet, self).__init__()
        # Define the input configurations
        self.img_size = img_size # size of the input images (e.g. (3, 224, 224))
        self.prototype_shape = prototype_shape # shape of the prototype vectors (e.g. (2000, 512, 1, 1))
        self.num_characteristics = num_characteristics # number of characteristics to predict (e.g. shape, margin, etc.)
        self.num_classes = 2 # binary classification
        
        self.num_prototypes = self.prototype_shape[0] # total number of prototypes
        self.prototypes_per_characteristic = self.num_prototypes // self.num_characteristics # number of prototypes per characteristic
        self.prototypes_per_class = self.prototypes_per_characteristic // self.num_classes # number of prototypes per class
        
        self.proto_layer_rf_info = proto_layer_rf_info
        self.epsilon = 1e-4 # small value to avoid numerical instability

        self.prototype_activation_function = prototype_activation_function # activation function for the prototypes
        
        self.prototype_class_identity = self._get_prototype_class_identity() # class identity of the prototypes
        
        # Define the feature extractor
        self.features = features
        
        # Define the add-on layers
        first_add_on_layer_in_channels = features.get_output_channels()
        # self.add_on_layers = self.initialize_add_on_layers(first_add_on_layer_in_channels, add_on_layers_type)
        if add_on_layers_type == 'bottleneck':
            add_on_layers = []
            current_in_channels = first_add_on_layer_in_channels
            while (current_in_channels > self.prototype_shape[1]) or (len(add_on_layers) == 0):
                current_out_channels = max(self.prototype_shape[1], (current_in_channels // 2))
                add_on_layers.append(nn.Conv2d(in_channels=current_in_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                add_on_layers.append(nn.ReLU())
                add_on_layers.append(nn.Conv2d(in_channels=current_out_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                if current_out_channels > self.prototype_shape[1]:
                    add_on_layers.append(nn.ReLU())
                else:
                    assert(current_out_channels == self.prototype_shape[1])
                    add_on_layers.append(nn.Sigmoid())
                current_in_channels = current_in_channels // 2
            self.add_on_layers = nn.Sequential(*add_on_layers)
        else:
            self.add_on_layers = nn.Sequential(
                nn.Conv2d(in_channels=first_add_on_layer_in_channels, out_channels=self.prototype_shape[1], kernel_size=1),
                nn.BatchNorm2d(self.prototype_shape[1]),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1),
                nn.Sigmoid()
                )
        
        # Define separate prototype vectors for each characteristic
        self.prototype_vectors = nn.ParameterList([
            nn.Parameter(torch.rand(self.prototypes_per_characteristic, prototype_shape[1], prototype_shape[2], prototype_shape[3]), requires_grad=True) for _ in range(self.num_characteristics)
        ])
        
        # Define a tensor of ones for the l2-convolution
        self.ones = nn.Parameter(torch.ones(self.prototypes_per_characteristic, prototype_shape[1], prototype_shape[2], prototype_shape[3]), requires_grad=False)

        # Define a separate classifier for each characteristic
        self.task_specific_classifier = nn.ModuleList([
            nn.Linear(self.prototypes_per_characteristic, self.num_classes) for _ in range(self.num_characteristics)   # Apply softmax to get confidence scores for each class of each characteristic
        ])
        
        self.final_classifier = nn.Sequential(
            nn.Flatten(),
            # nn.Linear(self.num_characteristics*self.prototypes_per_characteristic*12*12, self.num_characteristics*self.num_classes), # HxW is the output size of the feature extractor
            # nn.BatchNorm1d(self.num_characteristics*self.num_classes),
            # nn.ReLU(),
            # nn.Dropout(0.2),
            nn.Linear(self.num_characteristics*self.prototypes_per_characteristic, 1)
        )

        if init_weights:
            self._initialize_weights()
            self._set_last_layer_incorrect_connection(-0.5)
    
    def _get_prototype_class_identity(self):
        """
        Initialize the class identities of the prototypes structured by characteristics.
        Each characteristic has a tensor of size [num_prototypes_per_characteristic, num_classes].
        """
        prototype_class_identity = []
        num_prototypes_per_class = self.prototypes_per_characteristic // self.num_classes
        
        # Create a separate class identity matrix for each characteristic
        for _ in range(self.num_characteristics):
            # Initialize a zero matrix for current characteristic
            class_identity = torch.zeros(self.prototypes_per_characteristic, self.num_classes)
            
            # Assign prototypes to each class (binary: two classes per characteristic)
            for j in range(self.prototypes_per_characteristic):
                class_index = j // num_prototypes_per_class
                class_identity[j, class_index] = 1
            
            prototype_class_identity.append(class_identity)
        
        return prototype_class_identity
    
    def _set_last_layer_incorrect_connection(self, incorrect_strength):
        '''
        the incorrect strength will be actual strength if -0.5 then input -0.5
        '''
        for i in range(self.num_characteristics):
            positive_one_weights_locations = torch.t(self.prototype_class_identity[i])
            negative_one_weights_locations = 1 - positive_one_weights_locations
            
            correct_class_connection = 1
            incorrect_class_connection = incorrect_strength
            self.task_specific_classifier[i].weight.data.copy_(
                correct_class_connection * positive_one_weights_locations
                + incorrect_class_connection * negative_one_weights_locations)

    def _initialize_weights(self):
        for m in self.add_on_layers.modules():
            if isinstance(m, nn.Conv2d):
                # every init technique has an underscore _ in the name
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def _l2_convolution(self, x, prototype_vector):
        '''
        apply prototype_vector as l2-convolution filters on input x
        '''
        x2 = x ** 2
        x2_patch_sum = F.conv2d(input=x2, weight=self.ones)

        p2 = prototype_vector ** 2
        p2 = torch.sum(p2, dim=(1, 2, 3))
        # p2 is a vector of shape (num_prototypes_per_characteristic,)
        # then we reshape it to (num_prototypes_per_characteristic, 1, 1)
        p2_reshape = p2.view(-1, 1, 1)

        xp = F.conv2d(input=x, weight=prototype_vector)
        intermediate_result = - 2 * xp + p2_reshape  # use broadcast
        # x2_patch_sum and intermediate_result are of the same shape
        distances = F.relu(x2_patch_sum + intermediate_result)

        return distances

    def distance_2_similarity(self, distances):
        if self.prototype_activation_function == 'log':
            return torch.log((distances + 1) / (distances + self.epsilon))
        elif self.prototype_activation_function == 'linear':
            return -distances
        else:
            return self.prototype_activation_function(distances)
        
    def forward(self, x):
        # Extract features using the backbone
        x = self.features(x) # B x 256 x H x W
        
        # Apply add-on layers to the features
        x = self.add_on_layers(x) # B x 512 x H x W
        
        # Compute distances and task logits for each characteristic
        task_logits = []
        task_probabilities = []
        similarities = []
        min_distances = []
        for i in range(self.num_characteristics):
            distance = self._l2_convolution(x, self.prototype_vectors[i]) # B x num_prototypes_per_characteristic x H x W, B x num_prototypes_per_characteristic x 1 x 1
            min_distance = -F.max_pool2d(-distance, kernel_size=(distance.size()[2], distance.size()[3])) # B x num_prototypes_per_characteristic x 1 x 1
            min_distance = min_distance.view(-1, self.prototypes_per_characteristic) # B x num_prototypes_per_characteristic
            similarity = self.distance_2_similarity(min_distance) # B x num_prototypes_per_characteristic
            task_logit = self.task_specific_classifier[i](similarity) # B x 2
            task_probability = F.softmax(task_logit, dim=1)
                        
            similarities.append(similarity)
            min_distances.append(min_distance)
            task_logits.append(task_logit)
            task_probabilities.append(task_probability)
        
        # Concatenate task distances for the final classifier
        # final_output = torch.sigmoid(self.final_classifier(torch.cat(distances, dim=1))) # TODO: Use intermediate task logits instead of distances and feature extractor output
        print(torch.cat(similarities, dim=1).shape)
        final_output = torch.sigmoid(self.final_classifier(torch.cat(similarities, dim=1)))
        return final_output, task_logits, min_distances
                    
    def push_forward(self, x):
        '''this method is needed for the pushing operation'''
        x = self.features(x)
        x = self.add_on_layers(x)
        distances = []
        for i in range(self.num_characteristics):
            distance = self._l2_convolution(x, self.prototype_vectors[i])
            distances.append(distance)
        return x, distances
    
    
    # TODO: Implement the pruning operation


def construct_PPNet(
    base_architecture='denseNet121', 
    weights='DEFAULT', 
    img_size=224,
    prototype_shape=(50*5*2, 224, 1, 1), 
    num_characteristics=5,
    prototype_activation_function='log',
    add_on_layers_type='bottleneck'
):
    
    features = BACKBONE_DICT[base_architecture](weights=weights)
    
    layer_filter_sizes, layer_strides, layer_paddings = features.conv_info()
    
    proto_layer_rf_info = compute_proto_layer_rf_info_v2(
        img_size=img_size,  
        layer_filter_sizes=layer_filter_sizes,
        layer_strides=layer_strides,
        layer_paddings=layer_paddings,
        prototype_kernel_size=prototype_shape[2]
    )
    
    return PPNet(
        features=features,
        img_size=img_size,
        prototype_shape=prototype_shape,
        num_characteristics=num_characteristics,
        proto_layer_rf_info=proto_layer_rf_info,
        init_weights=True,
        prototype_activation_function=prototype_activation_function,
        add_on_layers_type=add_on_layers_type
    )
    
# Set torch seed for reproducibility
torch.manual_seed(42)

# model = PPNet(features=denseFPN_121(), img_size=(3, 100, 100), prototype_shape=(10*5*2, 224, 1, 1), num_characteristics=5, init_weights=True, prototype_activation_function='log', add_on_layers_type='bottleneck')

model = construct_PPNet(base_architecture='denseFPN_121', weights='DEFAULT', img_size=100, prototype_shape=(10*5*2, 224, 1, 1), num_characteristics=5, prototype_activation_function='log', add_on_layers_type='bottleneck')

# Print total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters: ", total_params)

# Create a dummy input tensor of size [50, 3, 100, 100]
dummy_input = torch.randn(10, 3, 100, 100)

# Forward pass through the model with dummy input
final_output, task_logits, min_distances = model(dummy_input)

# Print output shapes to verify
print("task_logits shape:", [task_logit.shape for task_logit in task_logits])  
print("final_output shape:", final_output.shape)

Total number of parameters:  7778845
torch.Size([10, 20])
torch.Size([10, 20])
torch.Size([10, 20])
torch.Size([10, 20])
torch.Size([10, 20])
torch.Size([10, 100])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x100 and 10x1)

In [2]:
from src.loaders.dataloader import LIDCDataset
import torchvision.transforms as transforms
from PIL import Image
import os
import torch

IMG_CHANNELS = 3
IMG_SIZE = 100
CHOSEN_CHARS = [False, True, False, True, True, False, False, True, True]
labels_file = './dataset/Meta/meta_info_old.csv'

mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3), 
    transforms.Resize(256),  # First resize to larger dimensions
    transforms.CenterCrop(224),  # Then crop to 224x224
    transforms.ToTensor(),  # Convert to tensor (also scales pixel values to [0, 1])
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# train set
LIDC_trainset = LIDCDataset(labels_file=labels_file, chosen_chars=CHOSEN_CHARS, auto_split=True, zero_indexed=False, 
                                                        transform=transforms.Compose([transforms.Grayscale(num_output_channels=IMG_CHANNELS), 
                                                                    transforms.Resize(size=(IMG_SIZE, IMG_SIZE), interpolation=Image.BILINEAR), 
                                                                    transforms.ToTensor(), 
                                                                    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                                    transforms.Normalize(mean, std)
                                                                    ]),
                                                        train=True)
train_dataloader = torch.utils.data.DataLoader(LIDC_trainset, batch_size=50, shuffle=True, num_workers=4)
# test set
LIDC_testset = LIDCDataset(labels_file=labels_file, chosen_chars=CHOSEN_CHARS, auto_split=True, zero_indexed=False, 
                                                        transform=transforms.Compose([transforms.Grayscale(num_output_channels=IMG_CHANNELS), 
                                                                    transforms.Resize(size=(IMG_SIZE, IMG_SIZE), interpolation=Image.BILINEAR), 
                                                                    transforms.ToTensor(), 
                                                                    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                                    transforms.Normalize(mean, std)
                                                                    ]), 
                                                        train=False)
test_dataloader = torch.utils.data.DataLoader(LIDC_testset, batch_size=50, shuffle=True, num_workers=4)


  warn(


In [3]:
from src.training.train_ppnet import train_ppnet, test_ppnet
from src.models.ProtoPNet import construct_PPNet
import torch

epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = construct_PPNet(base_architecture='denseFPN_121', weights='DEFAULT', img_size=100, prototype_shape=(10*5*2, 224, 1, 1), num_characteristics=5, prototype_activation_function='log', add_on_layers_type='bottleneck')

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
coefs = {
    'crs_ent': 1,#0.4,#1.1,#0.8,#1.1,#1,#0.5,#changed from 1 at 48
    'clst': 0.8*1.5,#0.2,#0.3,#1.1,#0.8,#5,#0.8,
    'sep': -0.0004,#-0.0004,#-0.17,#-0.22,#-0.5, #used to be -0.08 #dm made it smaller to avoid the problem I noticed where as the separation loss is subtracted, having it too large makes your loss negative making everything explode (also -0.025 works but unstable)
    'l1': 1e-4
}

task_weights = [1.0 / 5] * 5
for epoch in range(epochs):
    print("\n" + "-"*100 + f"\nEpoch: {epoch + 1}/{epochs},\t" + f"Task Weights: {[f'{weight:.2f}' for weight in task_weights]}\n" + "-"*100)
    train_metrics,task_weights = train_ppnet(model, train_dataloader, optimizer, device, coefs=coefs, task_weights=task_weights)
    test_metrics = test_ppnet(model, test_dataloader, device, coefs=coefs)



----------------------------------------------------------------------------------------------------
Epoch: 1/10,	Task Weights: ['0.20', '0.20', '0.20', '0.20', '0.20']
----------------------------------------------------------------------------------------------------


  0%|          | 0/237 [00:00<?, ?it/s]

0.2
0.2
0.2
0.2
0.2


  1%|          | 2/237 [00:01<02:44,  1.43it/s]

0.2
0.2
0.2
0.2
0.2


  1%|▏         | 3/237 [00:01<02:08,  1.82it/s]

0.2
0.2
0.2
0.2
0.2


  2%|▏         | 4/237 [00:02<01:53,  2.05it/s]

0.2
0.2
0.2
0.2
0.2


  2%|▏         | 5/237 [00:02<01:43,  2.23it/s]

0.2
0.2
0.2
0.2
0.2


  3%|▎         | 6/237 [00:03<01:36,  2.39it/s]

0.2
0.2
0.2
0.2
0.2


  3%|▎         | 7/237 [00:03<01:32,  2.49it/s]

0.2
0.2
0.2
0.2
0.2


  3%|▎         | 8/237 [00:03<01:25,  2.68it/s]

0.2
0.2
0.2
0.2
0.2


  4%|▍         | 9/237 [00:04<01:24,  2.71it/s]

0.2
0.2
0.2
0.2
0.2


  4%|▍         | 10/237 [00:04<01:32,  2.45it/s]

0.2
0.2
0.2
0.2
0.2


  5%|▍         | 11/237 [00:04<01:29,  2.52it/s]

0.2
0.2
0.2
0.2
0.2


  5%|▌         | 12/237 [00:05<01:24,  2.65it/s]

0.2
0.2
0.2
0.2
0.2


  5%|▌         | 13/237 [00:05<01:23,  2.68it/s]

0.2
0.2
0.2
0.2
0.2


  6%|▌         | 14/237 [00:05<01:20,  2.77it/s]

0.2
0.2
0.2
0.2
0.2


  6%|▋         | 15/237 [00:06<01:20,  2.77it/s]

0.2
0.2
0.2
0.2
0.2


  7%|▋         | 16/237 [00:06<01:18,  2.80it/s]

0.2
0.2
0.2
0.2
0.2


  7%|▋         | 17/237 [00:07<01:17,  2.83it/s]

0.2
0.2
0.2
0.2
0.2


  8%|▊         | 18/237 [00:07<01:17,  2.82it/s]

0.2
0.2
0.2
0.2
0.2


  8%|▊         | 19/237 [00:07<01:16,  2.84it/s]

0.2
0.2
0.2
0.2
0.2


  8%|▊         | 20/237 [00:08<01:15,  2.86it/s]

0.2
0.2
0.2
0.2
0.2


  9%|▉         | 21/237 [00:08<01:25,  2.54it/s]

0.2
0.2
0.2
0.2
0.2


  9%|▉         | 22/237 [00:08<01:21,  2.63it/s]

0.2
0.2
0.2
0.2
0.2


 10%|▉         | 23/237 [00:09<01:18,  2.73it/s]

0.2
0.2
0.2
0.2
0.2


 10%|█         | 24/237 [00:09<01:17,  2.75it/s]

0.2
0.2
0.2
0.2
0.2


 11%|█         | 25/237 [00:10<01:19,  2.68it/s]

0.2
0.2
0.2
0.2
0.2


 11%|█         | 26/237 [00:10<01:24,  2.49it/s]

0.2
0.2
0.2
0.2
0.2


 11%|█▏        | 27/237 [00:10<01:22,  2.56it/s]

0.2
0.2
0.2
0.2
0.2


 12%|█▏        | 28/237 [00:11<01:21,  2.55it/s]

0.2
0.2
0.2
0.2
0.2


 12%|█▏        | 29/237 [00:11<01:21,  2.55it/s]

0.2
0.2
0.2
0.2
0.2


 13%|█▎        | 30/237 [00:12<01:30,  2.29it/s]

0.2
0.2
0.2
0.2
0.2


 13%|█▎        | 31/237 [00:12<01:22,  2.49it/s]

0.2
0.2
0.2
0.2
0.2


 14%|█▎        | 32/237 [00:12<01:21,  2.52it/s]

0.2
0.2
0.2
0.2
0.2


 14%|█▍        | 33/237 [00:13<01:17,  2.65it/s]

0.2
0.2
0.2
0.2
0.2


 14%|█▍        | 34/237 [00:13<01:15,  2.70it/s]

0.2
0.2
0.2
0.2
0.2


 15%|█▍        | 35/237 [00:14<01:18,  2.56it/s]

0.2
0.2
0.2
0.2
0.2


 15%|█▌        | 36/237 [00:14<01:18,  2.57it/s]

0.2
0.2
0.2
0.2
0.2


 16%|█▌        | 37/237 [00:14<01:17,  2.58it/s]

0.2
0.2
0.2
0.2
0.2


 16%|█▌        | 38/237 [00:15<01:17,  2.58it/s]

0.2
0.2
0.2
0.2
0.2


 16%|█▋        | 39/237 [00:15<01:15,  2.61it/s]

0.2
0.2
0.2
0.2
0.2


 17%|█▋        | 40/237 [00:15<01:14,  2.63it/s]

0.2
0.2
0.2
0.2
0.2


 17%|█▋        | 41/237 [00:16<01:29,  2.19it/s]

0.2
0.2
0.2
0.2
0.2


 18%|█▊        | 42/237 [00:16<01:25,  2.28it/s]

0.2
0.2
0.2
0.2
0.2


 18%|█▊        | 43/237 [00:17<01:20,  2.41it/s]

0.2
0.2
0.2
0.2
0.2


 19%|█▊        | 44/237 [00:17<01:18,  2.47it/s]

0.2
0.2
0.2
0.2
0.2


 19%|█▉        | 45/237 [00:18<01:19,  2.43it/s]

0.2
0.2
0.2
0.2
0.2


 19%|█▉        | 46/237 [00:18<01:20,  2.38it/s]

0.2
0.2
0.2
0.2
0.2


 20%|█▉        | 47/237 [00:18<01:16,  2.49it/s]

0.2
0.2
0.2
0.2
0.2


 20%|██        | 48/237 [00:19<01:14,  2.55it/s]

0.2
0.2
0.2
0.2
0.2


 21%|██        | 49/237 [00:19<01:12,  2.60it/s]

0.2
0.2
0.2
0.2
0.2


 21%|██        | 50/237 [00:20<01:19,  2.35it/s]

0.2
0.2
0.2
0.2
0.2


 22%|██▏       | 51/237 [00:20<01:18,  2.36it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 22%|██▏       | 53/237 [00:21<01:29,  2.06it/s]

0.2
0.2
0.2
0.2
0.2


 23%|██▎       | 54/237 [00:22<01:23,  2.19it/s]

0.2
0.2
0.2
0.2
0.2


 23%|██▎       | 55/237 [00:22<01:17,  2.36it/s]

0.2
0.2
0.2
0.2
0.2


 24%|██▎       | 56/237 [00:22<01:12,  2.50it/s]

0.2
0.2
0.2
0.2
0.2


 24%|██▍       | 57/237 [00:23<01:09,  2.58it/s]

0.2
0.2
0.2
0.2
0.2


 24%|██▍       | 58/237 [00:23<01:07,  2.66it/s]

0.2
0.2
0.2
0.2
0.2


 25%|██▍       | 59/237 [00:23<01:07,  2.63it/s]

0.2
0.2
0.2
0.2
0.2


 25%|██▌       | 60/237 [00:24<01:04,  2.73it/s]

0.2
0.2
0.2
0.2
0.2


 26%|██▌       | 61/237 [00:24<01:04,  2.72it/s]

0.2
0.2
0.2
0.2
0.2


 26%|██▌       | 62/237 [00:24<01:03,  2.77it/s]

0.2
0.2
0.2
0.2
0.2


 27%|██▋       | 63/237 [00:25<01:15,  2.31it/s]

0.2
0.2
0.2
0.2
0.2


 27%|██▋       | 64/237 [00:25<01:11,  2.41it/s]

0.2
0.2
0.2
0.2
0.2


 27%|██▋       | 65/237 [00:26<01:07,  2.55it/s]

0.2
0.2
0.2
0.2
0.2


 28%|██▊       | 66/237 [00:26<01:06,  2.55it/s]

0.2
0.2
0.2
0.2
0.2


 28%|██▊       | 67/237 [00:27<01:04,  2.63it/s]

0.2
0.2
0.2
0.2
0.2


 29%|██▊       | 68/237 [00:27<01:03,  2.68it/s]

0.2
0.2
0.2
0.2
0.2


 29%|██▉       | 69/237 [00:27<01:01,  2.71it/s]

0.2
0.2
0.2
0.2
0.2


 30%|██▉       | 70/237 [00:28<01:02,  2.69it/s]

0.2
0.2
0.2
0.2
0.2


 30%|██▉       | 71/237 [00:28<01:00,  2.73it/s]

0.2
0.2
0.2
0.2
0.2


 30%|███       | 72/237 [00:28<00:59,  2.79it/s]

0.2
0.2
0.2
0.2
0.2


 31%|███       | 73/237 [00:29<00:58,  2.79it/s]

0.2
0.2
0.2
0.2
0.2


 31%|███       | 74/237 [00:29<01:09,  2.36it/s]

0.2
0.2
0.2
0.2
0.2


 32%|███▏      | 75/237 [00:30<01:06,  2.43it/s]

0.2
0.2
0.2
0.2
0.2


 32%|███▏      | 76/237 [00:30<01:03,  2.54it/s]

0.2
0.2
0.2
0.2
0.2


 32%|███▏      | 77/237 [00:30<01:01,  2.60it/s]

0.2
0.2
0.2
0.2
0.2


 33%|███▎      | 78/237 [00:31<00:59,  2.69it/s]

0.2
0.2
0.2
0.2
0.2


 33%|███▎      | 79/237 [00:31<00:58,  2.71it/s]

0.2
0.2
0.2
0.2
0.2


 34%|███▍      | 80/237 [00:31<00:57,  2.71it/s]

0.2
0.2
0.2
0.2
0.2


 34%|███▍      | 81/237 [00:32<00:57,  2.73it/s]

0.2
0.2
0.2
0.2
0.2


 35%|███▍      | 82/237 [00:32<00:57,  2.68it/s]

0.2
0.2
0.2
0.2
0.2


 35%|███▌      | 83/237 [00:33<00:56,  2.71it/s]

0.2
0.2
0.2
0.2
0.2


 35%|███▌      | 84/237 [00:33<00:55,  2.73it/s]

0.2
0.2
0.2
0.2
0.2


 36%|███▌      | 85/237 [00:34<01:10,  2.16it/s]

0.2
0.2
0.2
0.2
0.2


 36%|███▋      | 86/237 [00:34<01:06,  2.26it/s]

0.2
0.2
0.2
0.2
0.2


 37%|███▋      | 87/237 [00:34<01:02,  2.38it/s]

0.2
0.2
0.2
0.2
0.2


 37%|███▋      | 88/237 [00:35<01:02,  2.38it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 38%|███▊      | 90/237 [00:36<01:09,  2.10it/s]

0.2
0.2
0.2
0.2
0.2


 38%|███▊      | 91/237 [00:36<01:03,  2.30it/s]

0.2
0.2
0.2
0.2
0.2


 39%|███▉      | 92/237 [00:37<01:03,  2.28it/s]

0.2
0.2
0.2
0.2
0.2


 39%|███▉      | 93/237 [00:37<01:00,  2.39it/s]

0.2
0.2
0.2
0.2
0.2


 40%|███▉      | 94/237 [00:37<00:57,  2.47it/s]

0.2
0.2
0.2
0.2
0.2


 40%|████      | 95/237 [00:38<00:55,  2.54it/s]

0.2
0.2
0.2
0.2
0.2


 41%|████      | 96/237 [00:38<01:06,  2.12it/s]

0.2
0.2
0.2
0.2
0.2


 41%|████      | 97/237 [00:39<01:02,  2.24it/s]

0.2
0.2
0.2
0.2
0.2


 41%|████▏     | 98/237 [00:39<00:58,  2.36it/s]

0.2
0.2
0.2
0.2
0.2


 42%|████▏     | 99/237 [00:40<00:55,  2.48it/s]

0.2
0.2
0.2
0.2
0.2


 42%|████▏     | 100/237 [00:40<00:53,  2.55it/s]

0.2
0.2
0.2
0.2
0.2


 43%|████▎     | 101/237 [00:40<00:53,  2.55it/s]

0.2
0.2
0.2
0.2
0.2


 43%|████▎     | 102/237 [00:41<00:51,  2.61it/s]

0.2
0.2
0.2
0.2
0.2


 43%|████▎     | 103/237 [00:41<00:54,  2.45it/s]

0.2
0.2
0.2
0.2
0.2


 44%|████▍     | 104/237 [00:42<00:52,  2.51it/s]

0.2
0.2
0.2
0.2
0.2


 44%|████▍     | 105/237 [00:42<00:54,  2.42it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 45%|████▌     | 107/237 [00:43<01:08,  1.89it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 46%|████▌     | 109/237 [00:44<01:10,  1.81it/s]

0.2
0.2
0.2
0.2
0.2


 46%|████▋     | 110/237 [00:45<01:02,  2.02it/s]

0.2
0.2
0.2
0.2
0.2


 47%|████▋     | 111/237 [00:45<00:58,  2.14it/s]

0.2
0.2
0.2
0.2
0.2


 47%|████▋     | 112/237 [00:46<00:55,  2.26it/s]

0.2
0.2
0.2
0.2
0.2


 48%|████▊     | 113/237 [00:46<00:50,  2.44it/s]

0.2
0.2
0.2
0.2
0.2


 48%|████▊     | 114/237 [00:46<00:49,  2.50it/s]

0.2
0.2
0.2
0.2
0.2


 49%|████▊     | 115/237 [00:47<00:47,  2.56it/s]

0.2
0.2
0.2
0.2
0.2


 49%|████▉     | 116/237 [00:47<00:46,  2.59it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 50%|████▉     | 118/237 [00:48<00:55,  2.14it/s]

0.2
0.2
0.2
0.2
0.2


 50%|█████     | 119/237 [00:49<00:53,  2.19it/s]

0.2
0.2
0.2
0.2
0.2


 51%|█████     | 120/237 [00:49<00:54,  2.13it/s]

0.2
0.2
0.2
0.2
0.2


 51%|█████     | 121/237 [00:50<00:52,  2.19it/s]

0.2
0.2
0.2
0.2
0.2


 51%|█████▏    | 122/237 [00:50<00:53,  2.15it/s]

0.2
0.2
0.2
0.2
0.2


 52%|█████▏    | 123/237 [00:50<00:52,  2.16it/s]

0.2
0.2
0.2
0.2
0.2


 52%|█████▏    | 124/237 [00:51<00:52,  2.16it/s]

0.2
0.2
0.2
0.2
0.2


 53%|█████▎    | 125/237 [00:51<00:52,  2.12it/s]

0.2
0.2
0.2
0.2
0.2


 53%|█████▎    | 126/237 [00:52<00:52,  2.11it/s]

0.2
0.2
0.2
0.2
0.2


 54%|█████▎    | 127/237 [00:52<00:49,  2.24it/s]

0.2
0.2
0.2
0.2
0.2


 54%|█████▍    | 128/237 [00:53<01:00,  1.81it/s]

0.2
0.2
0.2
0.2
0.2


 54%|█████▍    | 129/237 [00:54<00:57,  1.87it/s]

0.2
0.2
0.2
0.2
0.2


 55%|█████▍    | 130/237 [00:54<00:57,  1.87it/s]

0.2
0.2
0.2
0.2
0.2


 55%|█████▌    | 131/237 [00:55<00:54,  1.96it/s]

0.2
0.2
0.2
0.2
0.2


 56%|█████▌    | 132/237 [00:55<00:53,  1.95it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 56%|█████▌    | 133/237 [00:56<00:58,  1.78it/s]

0.2
0.2
0.2
0.2
0.2


 57%|█████▋    | 135/237 [00:57<01:03,  1.61it/s]

0.2
0.2
0.2
0.2
0.2


 57%|█████▋    | 136/237 [00:58<00:58,  1.74it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 58%|█████▊    | 138/237 [00:59<00:56,  1.75it/s]

0.2
0.2
0.2
0.2
0.2


 59%|█████▊    | 139/237 [00:59<00:53,  1.82it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 59%|█████▉    | 141/237 [01:01<01:04,  1.48it/s]

0.2
0.2
0.2
0.2
0.2


 60%|█████▉    | 142/237 [01:01<00:58,  1.63it/s]

0.2
0.2
0.2
0.2
0.2


 60%|██████    | 143/237 [01:02<00:58,  1.61it/s]

0.2
0.2
0.2
0.2
0.2


 61%|██████    | 144/237 [01:02<00:51,  1.79it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 62%|██████▏   | 146/237 [01:03<00:48,  1.88it/s]

0.2
0.2
0.2
0.2
0.2


 62%|██████▏   | 147/237 [01:04<00:43,  2.07it/s]

0.2
0.2
0.2
0.2
0.2


 62%|██████▏   | 148/237 [01:04<00:39,  2.25it/s]

0.2
0.2
0.2
0.2
0.2


 63%|██████▎   | 149/237 [01:05<00:37,  2.33it/s]

0.2
0.2
0.2
0.2
0.2


 63%|██████▎   | 150/237 [01:05<00:36,  2.37it/s]

0.2
0.2
0.2
0.2
0.2


 64%|██████▎   | 151/237 [01:05<00:35,  2.42it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 65%|██████▍   | 153/237 [01:07<00:41,  2.03it/s]

0.2
0.2
0.2
0.2
0.2


 65%|██████▍   | 154/237 [01:07<00:37,  2.21it/s]

0.2
0.2
0.2
0.2
0.2


 65%|██████▌   | 155/237 [01:07<00:35,  2.33it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 66%|██████▌   | 157/237 [01:09<00:42,  1.87it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 67%|██████▋   | 159/237 [01:10<00:40,  1.91it/s]

0.2
0.2
0.2
0.2
0.2


 68%|██████▊   | 160/237 [01:10<00:38,  2.01it/s]

0.2
0.2
0.2
0.2
0.2


 68%|██████▊   | 161/237 [01:10<00:35,  2.15it/s]

0.2
0.2
0.2
0.2
0.2


 68%|██████▊   | 162/237 [01:11<00:33,  2.22it/s]

0.2
0.2
0.2
0.2
0.2


 69%|██████▉   | 163/237 [01:11<00:31,  2.34it/s]

0.2
0.2
0.2
0.2
0.2


 69%|██████▉   | 164/237 [01:12<00:30,  2.37it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 70%|███████   | 166/237 [01:13<00:40,  1.75it/s]

0.2
0.2
0.2
0.2
0.2


 70%|███████   | 167/237 [01:13<00:35,  1.98it/s]

0.2
0.2
0.2
0.2
0.2


 71%|███████   | 168/237 [01:14<00:32,  2.14it/s]

0.2
0.2
0.2
0.2
0.2


 71%|███████▏  | 169/237 [01:14<00:31,  2.14it/s]

0.2
0.2
0.2
0.2
0.2


 72%|███████▏  | 170/237 [01:15<00:31,  2.14it/s]

0.2
0.2
0.2
0.2
0.2


 72%|███████▏  | 171/237 [01:15<00:29,  2.23it/s]

0.2
0.2
0.2
0.2
0.2


 73%|███████▎  | 172/237 [01:16<00:28,  2.25it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 73%|███████▎  | 173/237 [01:16<00:34,  1.86it/s]

0.2
0.2
0.2
0.2
0.2


 74%|███████▍  | 175/237 [01:17<00:31,  1.98it/s]

0.2
0.2
0.2
0.2
0.2


 74%|███████▍  | 176/237 [01:18<00:30,  1.99it/s]

0.2
0.2
0.2
0.2
0.2


 75%|███████▍  | 177/237 [01:18<00:28,  2.09it/s]

0.2
0.2
0.2
0.2
0.2


 75%|███████▌  | 178/237 [01:19<00:33,  1.78it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 76%|███████▌  | 180/237 [01:20<00:29,  1.96it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 77%|███████▋  | 182/237 [01:21<00:29,  1.86it/s]

0.2
0.2
0.2
0.2
0.2


 77%|███████▋  | 183/237 [01:22<00:28,  1.87it/s]

0.2
0.2
0.2
0.2
0.2


 78%|███████▊  | 184/237 [01:22<00:28,  1.88it/s]

0.2
0.2
0.2
0.2
0.2


 78%|███████▊  | 185/237 [01:23<00:26,  1.93it/s]

0.2
0.2
0.2
0.2
0.2


 78%|███████▊  | 186/237 [01:23<00:30,  1.65it/s]

0.2
0.2
0.2
0.2
0.2


 79%|███████▉  | 187/237 [01:24<00:28,  1.74it/s]

0.2
0.2
0.2
0.2
0.2


 79%|███████▉  | 188/237 [01:24<00:25,  1.94it/s]

0.2
0.2
0.2
0.2
0.2


 80%|███████▉  | 189/237 [01:25<00:24,  1.93it/s]

0.2
0.2
0.2
0.2
0.2


 80%|████████  | 190/237 [01:25<00:23,  2.00it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 81%|████████  | 191/237 [01:26<00:24,  1.89it/s]

0.2
0.2
0.2
0.2
0.2


 81%|████████  | 192/237 [01:27<00:24,  1.84it/s]

0.2
0.2
0.2
0.2
0.2


 82%|████████▏ | 194/237 [01:28<00:25,  1.71it/s]

0.2
0.2
0.2
0.2
0.2


 82%|████████▏ | 195/237 [01:28<00:22,  1.83it/s]

0.2
0.2
0.2
0.2
0.2


 83%|████████▎ | 196/237 [01:29<00:21,  1.92it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 84%|████████▎ | 198/237 [01:30<00:24,  1.56it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 84%|████████▍ | 199/237 [01:31<00:23,  1.60it/s]

0.2
0.2
0.2
0.2
0.2


 85%|████████▍ | 201/237 [01:32<00:21,  1.66it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 86%|████████▌ | 203/237 [01:33<00:17,  1.89it/s]

0.2
0.2
0.2
0.2
0.2


 86%|████████▌ | 204/237 [01:33<00:16,  1.98it/s]

0.2
0.2
0.2
0.2
0.2


 86%|████████▋ | 205/237 [01:34<00:15,  2.06it/s]

0.2
0.2
0.2
0.2
0.2


 87%|████████▋ | 206/237 [01:34<00:14,  2.12it/s]

0.2
0.2
0.2
0.2
0.2


 87%|████████▋ | 207/237 [01:35<00:14,  2.12it/s]

0.2
0.2
0.2
0.2
0.2


 88%|████████▊ | 208/237 [01:35<00:13,  2.19it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 89%|████████▊ | 210/237 [01:37<00:15,  1.71it/s]

0.2
0.2
0.2
0.2
0.2


 89%|████████▉ | 211/237 [01:37<00:13,  1.92it/s]

0.2
0.2
0.2
0.2
0.2


 89%|████████▉ | 212/237 [01:37<00:12,  2.00it/s]

0.2
0.2
0.2
0.2
0.2


 90%|████████▉ | 213/237 [01:38<00:12,  1.98it/s]

0.2
0.2
0.2
0.2
0.2


 90%|█████████ | 214/237 [01:38<00:11,  1.98it/s]

0.2
0.2
0.2
0.2
0.2


 91%|█████████ | 215/237 [01:39<00:10,  2.04it/s]

0.2
0.2
0.2
0.2
0.2


 91%|█████████ | 216/237 [01:39<00:09,  2.18it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 92%|█████████▏| 218/237 [01:40<00:09,  2.09it/s]

0.2
0.2
0.2
0.2
0.2


 92%|█████████▏| 219/237 [01:41<00:08,  2.15it/s]

0.2
0.2
0.2
0.2
0.2


 93%|█████████▎| 220/237 [01:41<00:07,  2.18it/s]

0.2
0.2
0.2
0.2
0.2


 93%|█████████▎| 221/237 [01:41<00:07,  2.28it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 94%|█████████▎| 222/237 [01:42<00:08,  1.82it/s]

0.2
0.2
0.2
0.2
0.2


 95%|█████████▍| 224/237 [01:43<00:06,  1.88it/s]

0.2
0.2
0.2
0.2
0.2


 95%|█████████▍| 225/237 [01:44<00:06,  1.93it/s]

0.2
0.2
0.2
0.2
0.2


 95%|█████████▌| 226/237 [01:44<00:05,  1.97it/s]

0.2
0.2
0.2
0.2
0.2


 96%|█████████▌| 227/237 [01:45<00:05,  1.99it/s]

0.2
0.2
0.2
0.2
0.2


 96%|█████████▌| 228/237 [01:45<00:04,  2.02it/s]

0.2
0.2
0.2
0.2
0.2


 97%|█████████▋| 229/237 [01:46<00:03,  2.08it/s]

0.2
0.2
0.2
0.2
0.2


 97%|█████████▋| 230/237 [01:46<00:03,  2.10it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 98%|█████████▊| 232/237 [01:47<00:02,  2.06it/s]

0.2
0.2
0.2
0.2
0.2


 98%|█████████▊| 233/237 [01:48<00:01,  2.11it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


 99%|█████████▊| 234/237 [01:48<00:01,  1.76it/s]

0.2
0.2
0.2
0.2
0.2


100%|█████████▉| 236/237 [01:50<00:00,  1.82it/s]

0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2
0.2


                                                 

186.50392541201307
[0.18020093197804912, 0.17267479180214965, 0.17424308409785755, 0.2812554344448441, 0.19162575767709947]
Train loss: 15.64419
Task 1 - Train Loss: 33.60818, Train Accuracy: 100.00%, Train Balanced Accuracy: 100.00%
Task 2 - Train Loss: 32.20453, Train Accuracy: 73.47%, Train Balanced Accuracy: 50.00%
Task 3 - Train Loss: 32.49702, Train Accuracy: 57.85%, Train Balanced Accuracy: 50.00%
Task 4 - Train Loss: 52.45524, Train Accuracy: 56.55%, Train Balanced Accuracy: 50.10%
Task 5 - Train Loss: 35.73896, Train Accuracy: 51.05%, Train Balanced Accuracy: 51.18%
Final Output - Train Accuracy: 42.81%, Train Balanced Accuracy: 49.38%, Train F1 Score: 50.30%


                                               

Test loss: 20.89410
Task 1 - Test Accuracy: 100.00%, Test Balanced Accuracy: 100.00%
Task 2 - Test Accuracy: 71.98%, Test Balanced Accuracy: 50.00%
Task 3 - Test Accuracy: 56.90%, Test Balanced Accuracy: 50.00%
Task 4 - Test Accuracy: 20.98%, Test Balanced Accuracy: 50.00%
Task 5 - Test Accuracy: 48.04%, Test Balanced Accuracy: 50.00%
Final Output - Test Accuracy: 36.93%, Test Balanced Accuracy: 50.00%

----------------------------------------------------------------------------------------------------
Epoch: 2/10,	Task Weights: ['0.22', '0.24', '0.24', '0.09', '0.20']
----------------------------------------------------------------------------------------------------


  0%|          | 1/237 [00:00<01:36,  2.44it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  1%|          | 2/237 [00:00<01:40,  2.35it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  1%|▏         | 3/237 [00:01<01:53,  2.06it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  2%|▏         | 5/237 [00:03<02:32,  1.52it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  3%|▎         | 6/237 [00:03<02:21,  1.64it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  3%|▎         | 7/237 [00:04<02:13,  1.73it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  3%|▎         | 8/237 [00:04<02:11,  1.74it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  4%|▍         | 9/237 [00:05<02:11,  1.74it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  4%|▍         | 10/237 [00:05<02:03,  1.84it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  5%|▍         | 11/237 [00:06<02:01,  1.86it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  5%|▌         | 12/237 [00:06<02:00,  1.86it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  5%|▌         | 13/237 [00:07<01:57,  1.90it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  6%|▋         | 15/237 [00:08<02:22,  1.56it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  7%|▋         | 16/237 [00:09<02:13,  1.66it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  7%|▋         | 17/237 [00:09<02:04,  1.77it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  8%|▊         | 18/237 [00:10<02:05,  1.74it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  8%|▊         | 19/237 [00:11<02:07,  1.71it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  8%|▊         | 20/237 [00:11<02:09,  1.67it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


  9%|▉         | 22/237 [00:12<02:09,  1.66it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 10%|▉         | 23/237 [00:13<01:59,  1.79it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 10%|█         | 24/237 [00:13<01:55,  1.85it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 11%|█         | 25/237 [00:14<01:50,  1.93it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 11%|█▏        | 27/237 [00:15<02:03,  1.70it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 12%|█▏        | 28/237 [00:16<01:56,  1.79it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 12%|█▏        | 29/237 [00:16<01:55,  1.80it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 13%|█▎        | 30/237 [00:17<01:49,  1.90it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 13%|█▎        | 31/237 [00:17<01:47,  1.92it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 14%|█▎        | 32/237 [00:18<01:43,  1.98it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 14%|█▍        | 33/237 [00:18<01:41,  2.01it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 14%|█▍        | 34/237 [00:19<01:40,  2.01it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 15%|█▍        | 35/237 [00:19<01:38,  2.05it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 16%|█▌        | 37/237 [00:21<02:01,  1.64it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 16%|█▌        | 38/237 [00:21<01:53,  1.76it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 16%|█▋        | 39/237 [00:22<01:48,  1.82it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 17%|█▋        | 40/237 [00:22<01:44,  1.88it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 17%|█▋        | 41/237 [00:23<01:40,  1.94it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 18%|█▊        | 42/237 [00:23<01:38,  1.99it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 18%|█▊        | 43/237 [00:24<01:35,  2.04it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 19%|█▊        | 44/237 [00:24<01:29,  2.16it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 19%|█▉        | 45/237 [00:24<01:27,  2.20it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 19%|█▉        | 46/237 [00:25<01:26,  2.21it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 20%|█▉        | 47/237 [00:25<01:22,  2.31it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 21%|██        | 49/237 [00:27<01:40,  1.87it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 21%|██        | 50/237 [00:27<01:31,  2.03it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 22%|██▏       | 51/237 [00:27<01:28,  2.10it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 22%|██▏       | 52/237 [00:28<01:26,  2.13it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 22%|██▏       | 53/237 [00:28<01:21,  2.25it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 23%|██▎       | 54/237 [00:29<01:22,  2.22it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 23%|██▎       | 55/237 [00:29<01:24,  2.15it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 24%|██▎       | 56/237 [00:30<01:21,  2.21it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 24%|██▍       | 57/237 [00:30<01:20,  2.23it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 24%|██▍       | 58/237 [00:31<01:22,  2.16it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 25%|██▍       | 59/237 [00:31<01:20,  2.21it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 25%|██▌       | 60/237 [00:31<01:20,  2.19it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 26%|██▌       | 62/237 [00:32<01:19,  2.21it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 27%|██▋       | 63/237 [00:33<01:43,  1.68it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 27%|██▋       | 65/237 [00:34<01:30,  1.91it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 28%|██▊       | 66/237 [00:35<01:26,  1.99it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 28%|██▊       | 67/237 [00:35<01:21,  2.10it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 29%|██▉       | 69/237 [00:36<01:17,  2.18it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 30%|██▉       | 70/237 [00:36<01:13,  2.27it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 30%|██▉       | 71/237 [00:37<01:15,  2.21it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 30%|███       | 72/237 [00:37<01:10,  2.35it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 31%|███       | 73/237 [00:38<01:09,  2.36it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 31%|███       | 74/237 [00:38<01:12,  2.25it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 32%|███▏      | 75/237 [00:38<01:08,  2.37it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 32%|███▏      | 76/237 [00:39<01:10,  2.29it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 32%|███▏      | 77/237 [00:39<01:14,  2.14it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 33%|███▎      | 78/237 [00:40<01:15,  2.10it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 33%|███▎      | 79/237 [00:41<01:17,  2.03it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 34%|███▍      | 80/237 [00:41<01:19,  1.98it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 34%|███▍      | 81/237 [00:42<01:19,  1.97it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 35%|███▍      | 82/237 [00:42<01:20,  1.92it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 35%|███▌      | 83/237 [00:43<01:52,  1.36it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 35%|███▌      | 84/237 [00:44<01:47,  1.42it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 36%|███▋      | 86/237 [00:45<01:33,  1.62it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 37%|███▋      | 87/237 [00:46<01:30,  1.66it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 37%|███▋      | 88/237 [00:46<01:26,  1.72it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 38%|███▊      | 89/237 [00:47<01:24,  1.75it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 38%|███▊      | 90/237 [00:47<01:26,  1.70it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 38%|███▊      | 91/237 [00:48<01:29,  1.63it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 39%|███▉      | 92/237 [00:49<01:25,  1.70it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 39%|███▉      | 93/237 [00:49<01:24,  1.70it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 40%|███▉      | 94/237 [00:50<01:25,  1.67it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 41%|████      | 96/237 [00:51<01:21,  1.73it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 41%|████      | 97/237 [00:52<01:24,  1.65it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 41%|████▏     | 98/237 [00:52<01:26,  1.60it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 42%|████▏     | 100/237 [00:53<01:16,  1.80it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 43%|████▎     | 102/237 [00:54<01:11,  1.90it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 43%|████▎     | 103/237 [00:55<01:07,  1.97it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 44%|████▍     | 105/237 [00:56<01:02,  2.13it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 45%|████▍     | 106/237 [00:56<01:02,  2.09it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 45%|████▌     | 107/237 [00:56<00:59,  2.18it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 46%|████▌     | 108/237 [00:57<01:02,  2.08it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 46%|████▌     | 109/237 [00:57<01:00,  2.11it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 46%|████▋     | 110/237 [00:58<00:59,  2.14it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 47%|████▋     | 111/237 [00:58<01:00,  2.07it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 48%|████▊     | 113/237 [00:59<01:00,  2.06it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 48%|████▊     | 114/237 [01:00<00:57,  2.15it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 49%|████▉     | 116/237 [01:01<00:55,  2.18it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 50%|████▉     | 118/237 [01:02<00:55,  2.13it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 50%|█████     | 119/237 [01:02<00:53,  2.19it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 51%|█████     | 120/237 [01:03<00:54,  2.15it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 51%|█████     | 121/237 [01:03<00:52,  2.22it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 51%|█████▏    | 122/237 [01:04<00:53,  2.15it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 52%|█████▏    | 124/237 [01:04<00:53,  2.12it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 53%|█████▎    | 125/237 [01:05<00:53,  2.10it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 53%|█████▎    | 126/237 [01:06<01:15,  1.47it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 54%|█████▍    | 128/237 [01:07<01:02,  1.73it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 54%|█████▍    | 129/237 [01:08<00:59,  1.83it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 55%|█████▍    | 130/237 [01:08<00:55,  1.92it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 55%|█████▌    | 131/237 [01:09<00:55,  1.92it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 56%|█████▌    | 132/237 [01:09<00:58,  1.80it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 56%|█████▌    | 133/237 [01:10<00:59,  1.74it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 57%|█████▋    | 135/237 [01:11<00:57,  1.76it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 57%|█████▋    | 136/237 [01:11<00:55,  1.83it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 58%|█████▊    | 138/237 [01:12<00:49,  2.01it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 59%|█████▉    | 140/237 [01:13<00:51,  1.87it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 59%|█████▉    | 141/237 [01:14<00:50,  1.90it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 60%|█████▉    | 142/237 [01:14<00:48,  1.95it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 60%|██████    | 143/237 [01:15<00:48,  1.94it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 61%|██████    | 145/237 [01:16<00:45,  2.04it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 62%|██████▏   | 146/237 [01:16<00:45,  2.00it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 62%|██████▏   | 148/237 [01:17<00:44,  1.99it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 63%|██████▎   | 149/237 [01:18<00:48,  1.81it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 63%|██████▎   | 150/237 [01:19<01:08,  1.27it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 64%|██████▎   | 151/237 [01:20<01:03,  1.36it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 64%|██████▍   | 152/237 [01:21<01:01,  1.38it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 65%|██████▍   | 153/237 [01:22<01:02,  1.33it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 65%|██████▍   | 154/237 [01:22<01:01,  1.35it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 65%|██████▌   | 155/237 [01:23<00:59,  1.37it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 66%|██████▌   | 156/237 [01:25<01:19,  1.03it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 66%|██████▌   | 157/237 [01:25<01:09,  1.15it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 67%|██████▋   | 158/237 [01:26<01:13,  1.08it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 67%|██████▋   | 159/237 [01:28<01:21,  1.05s/it]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 68%|██████▊   | 160/237 [01:28<01:10,  1.09it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 68%|██████▊   | 162/237 [01:29<00:57,  1.31it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 69%|██████▉   | 164/237 [01:30<00:46,  1.57it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 70%|██████▉   | 165/237 [01:31<00:45,  1.57it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 70%|███████   | 167/237 [01:32<00:45,  1.54it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 71%|███████   | 168/237 [01:33<00:45,  1.53it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 71%|███████▏  | 169/237 [01:34<00:46,  1.46it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 72%|███████▏  | 170/237 [01:35<00:46,  1.44it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 72%|███████▏  | 171/237 [01:36<00:59,  1.12it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 73%|███████▎  | 172/237 [01:37<01:01,  1.05it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 73%|███████▎  | 173/237 [01:38<00:59,  1.07it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 73%|███████▎  | 174/237 [01:39<00:57,  1.09it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 74%|███████▍  | 175/237 [01:41<01:21,  1.31s/it]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 74%|███████▍  | 176/237 [01:43<01:29,  1.46s/it]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 75%|███████▌  | 178/237 [01:44<01:02,  1.07s/it]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 76%|███████▌  | 179/237 [01:45<00:53,  1.09it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 76%|███████▌  | 180/237 [01:46<00:48,  1.17it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 76%|███████▋  | 181/237 [01:46<00:48,  1.15it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 77%|███████▋  | 183/237 [01:48<00:37,  1.44it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 78%|███████▊  | 184/237 [01:48<00:34,  1.54it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 78%|███████▊  | 185/237 [01:49<00:40,  1.27it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 79%|███████▉  | 187/237 [01:50<00:33,  1.51it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 79%|███████▉  | 188/237 [01:51<00:29,  1.64it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 80%|███████▉  | 189/237 [01:51<00:27,  1.72it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 80%|████████  | 190/237 [01:52<00:35,  1.34it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 81%|████████  | 192/237 [01:54<00:35,  1.27it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 82%|████████▏ | 194/237 [01:55<00:28,  1.52it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423
0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 82%|████████▏ | 195/237 [01:56<00:25,  1.62it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 83%|████████▎ | 196/237 [01:57<00:29,  1.41it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 83%|████████▎ | 197/237 [01:58<00:36,  1.10it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 84%|████████▎ | 198/237 [01:59<00:34,  1.12it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 84%|████████▍ | 199/237 [02:00<00:31,  1.22it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


 84%|████████▍ | 200/237 [02:00<00:29,  1.23it/s]

0.22454533796516998
0.24454513922564897
0.24016298841892628
0.09217733478361258
0.1985691996066423


                                                 

KeyboardInterrupt: 

In [26]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import copy
import time

from src.utils.receptive_field import compute_rf_prototype
from src.utils.helpers import makedir, find_high_activation_crop

# push each prototype to the nearest patch in the training set
def push_prototypes(dataloader, # pytorch dataloader (must be unnormalized in [0,1])
                    prototype_network_parallel, # pytorch network with prototype_vectors
                    class_specific=True,
                    preprocess_input_function=None, # normalize if needed
                    prototype_layer_stride=1,
                    root_dir_for_saving_prototypes=None, # if not None, prototypes will be saved here
                    epoch_number=None, # if not provided, prototypes saved previously will be overwritten
                    prototype_img_filename_prefix=None,
                    prototype_self_act_filename_prefix=None,
                    proto_bound_boxes_filename_prefix=None,
                    save_prototype_class_identity=True, # which class the prototype image comes from
                    prototype_activation_function_in_numpy=None):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    prototype_network_parallel.to(device).eval()

    # Assuming prototype_network_parallel.module.prototype_vectors is a list of tensors
    all_prototype_shapes = [prototypes.shape for prototypes in prototype_network_parallel.prototype_vectors]
    all_n_prototypes = [shape[0] for shape in all_prototype_shapes]
    
    # Saves the closest distance to the prototype
    all_global_min_proto_dist = [np.full(n_prototypes, np.inf) for n_prototypes in all_n_prototypes]
    
    # Saves the patch that minimizes the distance to the prototype
    all_global_min_fmap_patches = [np.zeros([n_prototypes, shape[1], shape[2], shape[3]]) 
                                   for n_prototypes, shape in zip(all_n_prototypes, all_prototype_shapes)]
    
    # Assuming the same bounding box and class identity handling applies to all characteristics
    # Initialize proto_rf_boxes and proto_bound_boxes with appropriate shapes
    if save_prototype_class_identity:
        proto_rf_boxes = [np.full([n_prototypes, 6], -1) for n_prototypes in all_n_prototypes]
        proto_bound_boxes = [np.full([n_prototypes, 6], -1) for n_prototypes in all_n_prototypes]
    else:
        proto_rf_boxes = [np.full([n_prototypes, 5], -1) for n_prototypes in all_n_prototypes]
        proto_bound_boxes = [np.full([n_prototypes, 5], -1) for n_prototypes in all_n_prototypes]

    # Create a directory to save the prototypes
    if root_dir_for_saving_prototypes != None:
        if epoch_number != None:
            proto_epoch_dir = os.path.join(root_dir_for_saving_prototypes,
                                           'epoch-'+str(epoch_number))
            makedir(proto_epoch_dir)
        else:
            proto_epoch_dir = root_dir_for_saving_prototypes
    else:
        proto_epoch_dir = None

    search_batch_size = dataloader.batch_size
    num_classes = prototype_network_parallel.num_classes

    for push_iter, (search_batch_input, search_y_chars, _, _, _) in enumerate(dataloader):
        search_batch_size = search_batch_input.shape[0]
        start_index_of_search_batch = push_iter * search_batch_size

        # Handle batch processing for each characteristic
        for characteristic_index, (global_min_proto_dist, global_min_fmap_patches, proto_rf_box, proto_bound_box) in enumerate(zip(all_global_min_proto_dist, all_global_min_fmap_patches, proto_rf_boxes, proto_bound_boxes)):
            update_prototypes_on_batch(search_batch_input, 
                                       start_index_of_search_batch,
                                       prototype_network_parallel, 
                                       global_min_proto_dist,
                                       global_min_fmap_patches, 
                                       proto_rf_box, 
                                       proto_bound_box,
                                       class_specific=class_specific, 
                                       search_y=search_y_chars[characteristic_index],
                                       num_classes=prototype_network_parallel.num_classes,
                                       preprocess_input_function=preprocess_input_function,
                                       prototype_layer_stride=prototype_layer_stride,
                                       dir_for_saving_prototypes=root_dir_for_saving_prototypes,
                                       prototype_img_filename_prefix=prototype_img_filename_prefix,
                                       prototype_self_act_filename_prefix=prototype_self_act_filename_prefix,
                                       prototype_activation_function_in_numpy=prototype_activation_function_in_numpy,
                                       characteristic_index=characteristic_index)  # Pass characteristic_index to handle each set individually

    # Save bounding boxes and receptive field information for each set of prototypes
    if root_dir_for_saving_prototypes and epoch_number is not None:
        for idx, (proto_rf_box, proto_bound_box) in enumerate(zip(proto_rf_boxes, proto_bound_boxes)):
            np.save(os.path.join(root_dir_for_saving_prototypes, f"{proto_bound_boxes_filename_prefix}_receptive_field_characteristic_{idx}_epoch_{epoch_number}.npy"), proto_rf_box)
            np.save(os.path.join(root_dir_for_saving_prototypes, f"{proto_bound_boxes_filename_prefix}_characteristic_{idx}_epoch_{epoch_number}.npy"), proto_bound_box)

    # Update prototype vectors for each characteristic
    for idx, (prototype_vectors, global_min_fmap_patches) in enumerate(zip(prototype_network_parallel.prototype_vectors, all_global_min_fmap_patches)):
        prototype_update = np.reshape(global_min_fmap_patches, prototype_vectors.shape)
        prototype_vectors.data.copy_(torch.tensor(prototype_update, dtype=torch.float32).to(device))


# update each prototype for current search batch
def update_prototypes_on_batch(search_batch_input,
                               start_index_of_search_batch,
                               prototype_network_parallel,
                               global_min_proto_dist, # this will be updated
                               global_min_fmap_patches, # this will be updated
                               proto_rf_boxes, # this will be updated
                               proto_bound_boxes, # this will be updated
                               class_specific=True,
                               search_y=None, # required if class_specific == True
                               num_classes=None, # required if class_specific == True
                               preprocess_input_function=None,
                               prototype_layer_stride=1,
                               dir_for_saving_prototypes=None,
                               prototype_img_filename_prefix=None,
                               prototype_self_act_filename_prefix=None,
                               prototype_activation_function_in_numpy=None,
                               characteristic_index=0):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    prototype_network_parallel.eval()

    # Preprocess the search batch
    if preprocess_input_function is not None:
        search_batch = preprocess_input_function(search_batch_input)
    else:
        search_batch = search_batch_input
        
    # Push the search batch through the network
    with torch.no_grad():
        search_batch = search_batch.to(device)
        protoL_input_torch, proto_dist_torch = prototype_network_parallel.push_forward(search_batch) # push the batch through the network
    
    # Send the data back to the cpu
    protoL_input_ = np.copy(protoL_input_torch.detach().cpu().numpy())  # (batch_size, C, H, W)
    proto_dist_ = np.copy(proto_dist_torch[characteristic_index].detach().cpu().numpy())    # (batch_size, num_characteristics, num_prototypes_per_characteristic, H, W)

    del protoL_input_torch, proto_dist_torch

    # Initialize class_to_img_index_dict
    if class_specific:
        class_to_img_index_dict = {key: [] for key in range(num_classes)}
        for img_index, img_y in enumerate(search_y):
            img_label = img_y.item()
            class_to_img_index_dict[img_label].append(img_index)
            
    # print(class_to_img_index_dict)

    prototype_shape = prototype_network_parallel.prototype_shape
    n_prototypes = prototype_network_parallel.prototypes_per_characteristic
    proto_h = prototype_shape[2]
    proto_w = prototype_shape[3]
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3] # C * H * W

    for j in range(n_prototypes):
        if class_specific:
            # target_class is the class of the class_specific prototype
            target_class = torch.argmax(prototype_network_parallel.prototype_class_identity[characteristic_index][j]).item()
            # if there is not images of the target_class from this batch
            # we go on to the next prototype
            if len(class_to_img_index_dict[target_class]) == 0:
                continue
            proto_dist_j = proto_dist_[class_to_img_index_dict[target_class]][:,j,:,:]
        else:
            # if it is not class specific, then we will search through every example
            proto_dist_j = proto_dist_[:,j,:,:]

        batch_min_proto_dist_j = np.amin(proto_dist_j)
        
        if batch_min_proto_dist_j < global_min_proto_dist[j]:
            batch_argmin_proto_dist_j = \
                list(np.unravel_index(np.argmin(proto_dist_j, axis=None),
                                      proto_dist_j.shape))
            if class_specific:
                '''
                change the argmin index from the index among
                images of the target class to the index in the entire search
                batch
                '''
                batch_argmin_proto_dist_j[0] = class_to_img_index_dict[target_class][batch_argmin_proto_dist_j[0]]

            # retrieve the corresponding feature map patch
            img_index_in_batch = batch_argmin_proto_dist_j[0]
            fmap_height_start_index = batch_argmin_proto_dist_j[1] * prototype_layer_stride
            fmap_height_end_index = fmap_height_start_index + proto_h
            fmap_width_start_index = batch_argmin_proto_dist_j[2] * prototype_layer_stride
            fmap_width_end_index = fmap_width_start_index + proto_w

            batch_min_fmap_patch_j = protoL_input_[img_index_in_batch,
                                                   :,
                                                   fmap_height_start_index:fmap_height_end_index,
                                                   fmap_width_start_index:fmap_width_end_index]

            global_min_proto_dist[j] = batch_min_proto_dist_j
            global_min_fmap_patches[j] = batch_min_fmap_patch_j
            
            # get the receptive field boundary of the image patch
            # that generates the representation
            protoL_rf_info = prototype_network_parallel.proto_layer_rf_info
            rf_prototype_j = compute_rf_prototype(search_batch.size(2), batch_argmin_proto_dist_j, protoL_rf_info)
            
            # get the whole image
            original_img_j = search_batch_input[rf_prototype_j[0]]
            original_img_j = original_img_j.numpy()
            original_img_j = np.transpose(original_img_j, (1, 2, 0))
            original_img_size = original_img_j.shape[0]
            
            # crop out the receptive field
            rf_img_j = original_img_j[rf_prototype_j[1]:rf_prototype_j[2],
                                      rf_prototype_j[3]:rf_prototype_j[4], :]
            
            # save the prototype receptive field information
            proto_rf_boxes[j, 0] = rf_prototype_j[0] + start_index_of_search_batch
            proto_rf_boxes[j, 1] = rf_prototype_j[1]
            proto_rf_boxes[j, 2] = rf_prototype_j[2]
            proto_rf_boxes[j, 3] = rf_prototype_j[3]
            proto_rf_boxes[j, 4] = rf_prototype_j[4]
            if proto_rf_boxes.shape[1] == 6 and search_y is not None:
                proto_rf_boxes[j, 5] = search_y[rf_prototype_j[0]].item()

            # find the highly activated region of the original image
            proto_dist_img_j = proto_dist_[img_index_in_batch, j, :, :]
            if prototype_network_parallel.prototype_activation_function == 'log':
                proto_act_img_j = np.log((proto_dist_img_j + 1) / (proto_dist_img_j + prototype_network_parallel.epsilon))
            elif prototype_network_parallel.prototype_activation_function == 'linear':
                proto_act_img_j = max_dist - proto_dist_img_j
            else:
                proto_act_img_j = prototype_activation_function_in_numpy(proto_dist_img_j)
            upsampled_act_img_j = cv2.resize(proto_act_img_j, dsize=(original_img_size, original_img_size),
                                             interpolation=cv2.INTER_CUBIC)
            proto_bound_j = find_high_activation_crop(upsampled_act_img_j)
            # crop out the image patch with high activation as prototype image
            proto_img_j = original_img_j[proto_bound_j[0]:proto_bound_j[1],
                                         proto_bound_j[2]:proto_bound_j[3], :]

            # save the prototype boundary (rectangular boundary of highly activated region)
            proto_bound_boxes[j, 0] = proto_rf_boxes[j, 0]
            proto_bound_boxes[j, 1] = proto_bound_j[0]
            proto_bound_boxes[j, 2] = proto_bound_j[1]
            proto_bound_boxes[j, 3] = proto_bound_j[2]
            proto_bound_boxes[j, 4] = proto_bound_j[3]
            
            if proto_bound_boxes.shape[1] == 6 and search_y is not None:
                proto_bound_boxes[j, 5] = search_y[rf_prototype_j[0]].item()

            if dir_for_saving_prototypes is not None:
                if prototype_self_act_filename_prefix is not None:
                    # save the numpy array of the prototype self activation
                    np.save(os.path.join(dir_for_saving_prototypes,
                                         prototype_self_act_filename_prefix + str(j) + '.npy'),
                            proto_act_img_j)
                if prototype_img_filename_prefix is not None:
                    # save the whole image containing the prototype as png
                    plt.imsave(os.path.join(dir_for_saving_prototypes,
                                            prototype_img_filename_prefix + '-original' + str(j) + '.png'),
                               original_img_j,
                               vmin=0.0,
                               vmax=1.0)
                    # overlay (upsampled) self activation on original image and save the result
                    rescaled_act_img_j = upsampled_act_img_j - np.amin(upsampled_act_img_j)
                    rescaled_act_img_j = rescaled_act_img_j / np.amax(rescaled_act_img_j)
                    heatmap = cv2.applyColorMap(np.uint8(255*rescaled_act_img_j), cv2.COLORMAP_JET)
                    heatmap = np.float32(heatmap) / 255
                    heatmap = heatmap[...,::-1]
                    overlayed_original_img_j = 0.5 * original_img_j + 0.3 * heatmap
                    plt.imsave(os.path.join(dir_for_saving_prototypes,
                                            prototype_img_filename_prefix + '-original_with_self_act' + str(j) + '.png'),
                               overlayed_original_img_j,
                               vmin=0.0,
                               vmax=1.0)
                    
                    # if different from the original (whole) image, save the prototype receptive field as png
                    if rf_img_j.shape[0] != original_img_size or rf_img_j.shape[1] != original_img_size:
                        plt.imsave(os.path.join(dir_for_saving_prototypes,
                                                prototype_img_filename_prefix + '-receptive_field' + str(j) + '.png'),
                                   rf_img_j,
                                   vmin=0.0,
                                   vmax=1.0)
                        overlayed_rf_img_j = overlayed_original_img_j[rf_prototype_j[1]:rf_prototype_j[2],
                                                                      rf_prototype_j[3]:rf_prototype_j[4]]
                        plt.imsave(os.path.join(dir_for_saving_prototypes,
                                                prototype_img_filename_prefix + '-receptive_field_with_self_act' + str(j) + '.png'),
                                   overlayed_rf_img_j,
                                   vmin=0.0,
                                   vmax=1.0)
                    
                    # save the prototype image (highly activated region of the whole image)
                    plt.imsave(os.path.join(dir_for_saving_prototypes,
                                            prototype_img_filename_prefix + str(j) + '.png'),
                               proto_img_j,
                               vmin=0.0,
                               vmax=1.0)
                
    if class_specific:
        del class_to_img_index_dict

In [27]:
epochs = 1

for epoch in range(epochs):
    push_prototypes(train_dataloader, 
                    model, 
                    class_specific=True, 
                    preprocess_input_function=None, 
                    prototype_layer_stride=1, 
                    epoch_number=epoch,
                    save_prototype_class_identity=False
                )

{0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], 1: []}
{0: [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 19, 20, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39, 41, 42, 45, 46, 47, 48], 1: [4, 12, 13, 14, 21, 24, 34, 40, 43, 44, 49]}
{0: [0, 1, 4, 5, 6, 9, 15, 18, 19, 20, 22, 25, 26, 29, 30, 31, 35, 36, 37, 38, 41, 42, 45, 47, 48], 1: [2, 3, 7, 8, 10, 11, 12, 13, 14, 16, 17, 21, 23, 24, 27, 28, 32, 33, 34, 39, 40, 43, 44, 46, 49]}
{0: [1, 4, 18, 19, 22, 29, 30, 31, 36, 37, 41, 45], 1: [0, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 23, 24, 25, 26, 27, 28, 32, 33, 34, 35, 38, 39, 40, 42, 43, 44, 46, 47, 48, 49]}
{0: [1, 4, 6, 9, 13, 14, 17, 18, 20, 21, 22, 23, 26, 27, 30, 31, 32, 34, 35, 36, 38, 39, 42, 44], 1: [0, 2, 3, 5, 7, 8, 10, 11, 12, 15, 16, 19, 24, 25, 28, 29, 33, 37, 40, 41, 43, 45, 46, 47, 48

In [1]:
from src.models.backbone_models import denseNet121

model = denseNet121()

kernel, stride, padding = model.conv_info()

print(len(kernel), len(stride), len(padding))
print(kernel)
print(stride)
print(padding)

  warn(


120 120 120
[7, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3]
[2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[3, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0

In [1]:
from src.models.backbone_models import denseFPN_121

model = denseFPN_121()

kernel, stride, padding = model.conv_info()

print(len(kernel), len(stride), len(padding))
print(kernel)
print(stride)
print(padding)

  warn(


127 127 127
[7, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1]
[2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[3, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1

In [76]:
from src.models.densenet_features import densenet121_features

model = densenet121_features()

kernel, stride, padding = model.conv_info()

print(kernel)
print(stride)
print(padding)

[7, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3]
[2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[3, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0

In [36]:
import torch.nn.functional as F

target = [[1,1,1,1,1, 1,1,1,1,1],[1, 0, 1, 0, 1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]]
targets = torch.tensor(target)
# one_hot_targets = [F.make]
print(targets)
one_hot_targets = [F.one_hot(t.long().squeeze()) for t in targets]
print(one_hot_targets)

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]])
[tensor([[0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1]]), tensor([[0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0]]), tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])]


In [3]:
from src.loaders.dataloaderv2 import LIDCDataset
import torchvision.transforms as transforms
from PIL import Image
import torch

labels_file = './dataset/Meta/meta_info_old.csv'
IMG_CHANNELS = 3
IMG_SIZE = 100
CHOSEN_CHARS = [False, True, False, True, True, False, False, True, True]

# labels_file = './dataset/Meta/meta_info_old.csv'
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# train set
LIDC_trainset = LIDCDataset(labels_file=labels_file, chosen_chars=CHOSEN_CHARS, auto_split=True, zero_indexed=False, 
                                                        transform=transforms.Compose([transforms.Grayscale(num_output_channels=IMG_CHANNELS), 
                                                                    transforms.Resize(size=(IMG_SIZE, IMG_SIZE), interpolation=Image.BILINEAR), 
                                                                    transforms.ToTensor(), 
                                                                    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                                    transforms.Normalize(mean, std)
                                                                    ]),
                                                        train=True)
train_dataloader = torch.utils.data.DataLoader(LIDC_trainset, batch_size=10, shuffle=True, num_workers=4)

LIDC_testset = LIDCDataset(labels_file=labels_file, chosen_chars=CHOSEN_CHARS, auto_split=True, zero_indexed=False, 
                                                        transform=transforms.Compose([transforms.Grayscale(num_output_channels=IMG_CHANNELS), 
                                                                    transforms.Resize(size=(IMG_SIZE, IMG_SIZE), interpolation=Image.BILINEAR), 
                                                                    transforms.ToTensor(), 
                                                                    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                                                    transforms.Normalize(mean, std)
                                                                    ]), 
                                                        train=False)
test_dataloader = torch.utils.data.DataLoader(LIDC_testset, batch_size=10, shuffle=True, num_workers=4)

print(len(train_dataloader))
print(len(test_dataloader))

['LIDC-IDRI-0001' 'LIDC-IDRI-0002' 'LIDC-IDRI-0003' 'LIDC-IDRI-0004'
 'LIDC-IDRI-0005' 'LIDC-IDRI-0006' 'LIDC-IDRI-0007' 'LIDC-IDRI-0008'
 'LIDC-IDRI-0009' 'LIDC-IDRI-0010' 'LIDC-IDRI-0011' 'LIDC-IDRI-0012'
 'LIDC-IDRI-0013' 'LIDC-IDRI-0014' 'LIDC-IDRI-0015' 'LIDC-IDRI-0016'
 'LIDC-IDRI-0017' 'LIDC-IDRI-0018' 'LIDC-IDRI-0019' 'LIDC-IDRI-0020'
 'LIDC-IDRI-0021' 'LIDC-IDRI-0022' 'LIDC-IDRI-0023' 'LIDC-IDRI-0024'
 'LIDC-IDRI-0025' 'LIDC-IDRI-0026' 'LIDC-IDRI-0027' 'LIDC-IDRI-0029'
 'LIDC-IDRI-0030' 'LIDC-IDRI-0031' 'LIDC-IDRI-0033' 'LIDC-IDRI-0034'
 'LIDC-IDRI-0035' 'LIDC-IDRI-0036' 'LIDC-IDRI-0037' 'LIDC-IDRI-0038'
 'LIDC-IDRI-0039' 'LIDC-IDRI-0040' 'LIDC-IDRI-0041' 'LIDC-IDRI-0042'
 'LIDC-IDRI-0043' 'LIDC-IDRI-0044' 'LIDC-IDRI-0045' 'LIDC-IDRI-0046'
 'LIDC-IDRI-0047' 'LIDC-IDRI-0048' 'LIDC-IDRI-0049' 'LIDC-IDRI-0050'
 'LIDC-IDRI-0051' 'LIDC-IDRI-0052' 'LIDC-IDRI-0053' 'LIDC-IDRI-0054'
 'LIDC-IDRI-0055' 'LIDC-IDRI-0056' 'LIDC-IDRI-0057' 'LIDC-IDRI-0058'
 'LIDC-IDRI-0059' 'LIDC-IDRI-0060'