# Install

# Data

In [None]:
import fastai
from fastai.vision.all import *
from tqdm import tqdm
from glob import glob

In [None]:
SEED = 85


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    NOTE: PYTHONHASHSEED only affects interpreters started after it is set
    (e.g. worker subprocesses), not the already-running one.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

In [None]:
# Global run configuration.
# NOTE(review): the DataBlock cells below hard-code Resize((224, 224)) and
# bs=64 / bs=128 instead of reading IMAGE_SIZE / BATCH_SIZE.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
IMAGE_SIZE = [224] * 2
BATCH_SIZE = 128
EPOCHS = 10

In [None]:
# The official split lists contain one image file name per line and no header
# row; without header=None pandas would consume the first file name as a header.
labels_train_val = pd.read_csv('/kaggle/input/data/train_val_list.txt', header=None, names=['Image_Index'])
labels_test = pd.read_csv('/kaggle/input/data/test_list.txt', header=None, names=['Image_Index'])
disease_labels = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening',
'Cardiomegaly', 'Nodule', 'Mass', 'Hernia']
# NIH Dataset Labels CSV File
labels_df = pd.read_csv('/kaggle/input/data/Data_Entry_2017.csv')
labels_df.columns = ['Image_Index', 'Finding_Labels', 'Follow_Up_#', 'Patient_ID',
                  'Patient_Age', 'Patient_Gender', 'View_Position',
                  'Original_Image_Width', 'Original_Image_Height',
                  'Original_Image_Pixel_Spacing_X',
                  'Original_Image_Pixel_Spacing_Y', 'dfd']
# One-hot encode each disease by *exact* match on the '|'-separated findings
# (a substring test would also fire on any label that merely contains the name).
finding_sets = labels_df['Finding_Labels'].astype(str).str.split('|').map(set)
for disease in tqdm(disease_labels):
    labels_df[disease] = finding_sets.map(lambda found, name=disease: int(name in found))

# Stage 1 trains only on images with at least one finding. .copy() makes the
# result an independent frame so later column assignments don't hit a view.
labels_df = labels_df[labels_df.Finding_Labels != 'No Finding'].copy()

In [None]:

# "A|B|C" -> ['A', 'B', 'C']
labels_df['Finding_Labels'] = labels_df['Finding_Labels'].apply(lambda raw: str(raw).split('|'))

# Index every PNG on disk by file name; rows whose image is not found map to None.
num_glob = glob('/kaggle/input/data/*/images/*.png')
img_path = {os.path.basename(png): png for png in num_glob}

labels_df['Paths'] = labels_df['Image_Index'].map(img_path.get)
labels_df.head()

In [None]:
# Sorted array of distinct patient ids (order matters for the seeded split below).
unique_patients = np.unique(labels_df['Patient_ID'].to_numpy())
unique_patients.size

In [None]:
from sklearn.model_selection import train_test_split

# Patient-level hold-out: 80% of patients -> train+valid, 20% -> test.
# The DataBlock later carves 12.5% of train+valid into validation,
# i.e. roughly train 70 / valid 10 / test 20.
train_val_df_patients, test_df_patients = train_test_split(
    unique_patients,
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
)
len(train_val_df_patients)

In [None]:
train_val_df = labels_df.loc[labels_df['Patient_ID'].isin(train_val_df_patients)]  # rows of train+valid patients

In [None]:
train_val_df.head(n=5)  # preview

In [None]:
# Image counts per split (test = every image not belonging to a train+valid patient).
n_train_val = len(train_val_df)
print('train_val size', n_train_val)
print('test size', len(labels_df) - n_train_val)

# Data builder

In [None]:
item_transforms = [
    Resize((224, 224)),
]

batch_transforms = [
    Flip(),
    Rotate(),
    Normalize.from_stats(*imagenet_stats),
]


def get_x(row):
    """Return the image path stored in the row's 'Paths' column."""
    return row['Paths']

def get_y(row):
    """Return the row's 0/1 flags for every name in `disease_labels`, in that order."""
    return row[disease_labels].tolist()

def patient_splitter(df, valid_pct=0.125, seed=SEED):
    """Split the row positions of `df` into (train, valid) *by patient*.

    `train_val_df` holds several images per `Patient_ID`. Splitting by image
    (as RandomSplitter does) lets one patient's images land in both train and
    valid, leaking information into the valid_loss that the SaveModel /
    EarlyStopping callbacks monitor. Here a seeded random `valid_pct`
    fraction of the patients is assigned to validation.

    Returns:
        (train_positions, valid_positions): lists of integer row positions.
    """
    patients = np.unique(df['Patient_ID'].to_numpy())
    rng = np.random.RandomState(seed)
    n_valid = int(round(valid_pct * len(patients)))
    valid_patients = rng.permutation(patients)[:n_valid]
    in_valid = df['Patient_ID'].isin(valid_patients).to_numpy()
    return np.flatnonzero(~in_valid).tolist(), np.flatnonzero(in_valid).tolist()

dblock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock(encoded=True, vocab=disease_labels)),
    splitter=patient_splitter,
    get_x=get_x,
    get_y=get_y,
    item_tfms=item_transforms,
    batch_tfms=batch_transforms,
)
dls = dblock.dataloaders(train_val_df, bs=64)
# print(dblock.datasets(train_val_merge).train)

# Model

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn import functional as F
from copy import deepcopy


# Memory Bank: Store rare/important features
class MemoryBank(nn.Module):
    """Fixed-size ring buffer of feature vectors with pluggable write policies.

    ``memory`` has ``bank_size`` slots. ``index`` is the next slot to write and
    wraps modulo ``bank_size``. A slot counts as populated once its
    ``memory_count`` is non-zero, so retrieval keeps using the whole bank after
    the write pointer wraps around. ``retrieve`` returns, for each query row, a
    softmax(cosine similarity)-weighted sum of its ``k`` most similar stored
    rows, skipping near-identical ones.

    Args:
        feature_dim: Width of each stored vector.
        bank_size: Number of slots.
        update_strategy: 'rarity', 'statistical', 'entropy', 'diversity',
            'hybrid', 'fifo' or 'reservoir' (see ``compute_importance_scores``).
        rarity_threshold: Default ``threshold`` for ``update``. It is an absolute
            score cut-off for 'rarity'/'statistical' and a quantile level for
            'entropy'/'diversity'/'hybrid'.
        diversity_weight: Weight of the diversity term in 'hybrid'.
        momentum: EMA momentum for the running statistics.

    NOTE(review): checkpoints written by earlier versions of this class did not
    bump ``memory_count`` on wrap-around writes. A few such slots may be
    ignored until they are written again.
    """

    def __init__(self, feature_dim, bank_size=1000, update_strategy='rarity',
                 rarity_threshold=0.2, diversity_weight=0.5, momentum=0.9):
        super(MemoryBank, self).__init__()
        self.feature_dim = feature_dim
        self.bank_size = bank_size
        self.update_strategy = update_strategy
        self.rarity_threshold = rarity_threshold
        self.diversity_weight = diversity_weight
        self.momentum = momentum

        self.register_buffer('memory', torch.zeros(bank_size, feature_dim))
        self.register_buffer('index', torch.tensor(0))                  # next write slot
        self.register_buffer('memory_count', torch.zeros(bank_size))    # writes per slot
        self.register_buffer('running_mean', torch.zeros(feature_dim))
        self.register_buffer('running_var', torch.ones(feature_dim))
        self.register_buffer('num_updates', torch.tensor(0))

        # For statistical strategy - EMA of per-row L2 norms
        self.register_buffer('running_mean_norm', torch.tensor(0.0))

    def _valid_memory(self):
        """Return the populated slots (written at least once), in slot order."""
        return self.memory[self.memory_count > 0]

    def compute_importance_scores(self, features):
        """Score each row of ``features`` for storage.

        ``update`` keeps rows whose score is *below* a cut-off, so for every
        strategy a lower score means "more likely to be stored".
        """
        batch_size = features.size(0)
        strategy = self.update_strategy

        if strategy == 'rarity':
            # Relative deviation of each row's L2 norm from the batch mean norm;
            # higher = more atypical. NOTE(review): since update() keeps
            # scores < threshold, this stores rows whose norm is *close* to the
            # batch mean. Confirm that is the intended meaning of "rarity".
            norms = torch.norm(features, dim=1)
            mean_norm = torch.mean(norms)
            return torch.abs(norms - mean_norm) / (mean_norm + 1e-8)

        if strategy == 'statistical':
            # Same as 'rarity' but measured against the EMA of norms.
            if self.num_updates > 0:
                norms = torch.norm(features, dim=1)
                return torch.abs(norms - self.running_mean_norm) / (self.running_mean_norm + 1e-8)
            return torch.zeros(batch_size, device=features.device)

        if strategy == 'entropy':
            # Placeholder: "entropy" of the L2-normalised feature vector (a real
            # uncertainty score would need logits). Negated so high entropy scores low.
            norm_features = F.normalize(features, dim=1)
            entropy = -torch.sum(norm_features * torch.log(torch.abs(norm_features) + 1e-8), dim=1)
            return -entropy

        if strategy == 'diversity':
            # Max cosine similarity to anything already stored; lower = more novel.
            valid_memory = self._valid_memory()
            if valid_memory.size(0) > 0:
                similarity = torch.matmul(F.normalize(features, dim=1),
                                          F.normalize(valid_memory, dim=1).T)
                return similarity.max(dim=1)[0]
            return torch.zeros(batch_size, device=features.device)

        if strategy == 'hybrid':
            # Weighted mix of norm deviation and novelty, negated so that
            # rarer / more novel rows get lower scores.
            norms = torch.norm(features, dim=1)
            mean_norm = torch.mean(norms)
            rarity = torch.abs(norms - mean_norm) / (mean_norm + 1e-8)

            valid_memory = self._valid_memory()
            if valid_memory.size(0) > 0:
                similarity = torch.matmul(F.normalize(features, dim=1),
                                          F.normalize(valid_memory, dim=1).T)
                diversity = 1 - similarity.max(dim=1)[0]
            else:
                diversity = torch.ones(batch_size, device=features.device)

            return (1 - self.diversity_weight) * (-rarity) + self.diversity_weight * (-diversity)

        if strategy == 'fifo':
            # Every row is stored; no scoring needed.
            return torch.zeros(batch_size, device=features.device)

        if strategy == 'reservoir':
            # Uniform random draws compared against an acceptance probability.
            return torch.rand(batch_size, device=features.device)

        raise ValueError(f"Unknown update strategy: {self.update_strategy}")

    def update_statistics(self, features):
        """Update EMA mean/variance of features (and of norms for 'statistical')."""
        batch_mean = features.mean(dim=0)
        batch_var = features.var(dim=0)

        # For statistical strategy - update running mean of norms
        if self.update_strategy == 'statistical':
            sample_norms = torch.norm(features, dim=1)
            batch_mean_norm = sample_norms.mean()

            if self.num_updates == 0:
                self.running_mean_norm = batch_mean_norm
            else:
                self.running_mean_norm = self.momentum * self.running_mean_norm + (1 - self.momentum) * batch_mean_norm

        if self.num_updates == 0:
            self.running_mean = batch_mean
            self.running_var = batch_var
        else:
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * batch_mean
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * batch_var

        self.num_updates += 1

    def update(self, features, threshold=None):
        """Select rows of ``features`` by the configured strategy and store them.

        Args:
            features: 2-D tensor, one row per sample, ``feature_dim`` columns.
            threshold: Overrides ``rarity_threshold`` for this call.
        """
        batch_size = features.size(0)

        self.update_statistics(features)
        scores = self.compute_importance_scores(features)

        if threshold is None:
            threshold = self.rarity_threshold

        if self.update_strategy == 'fifo':
            mask = torch.ones(batch_size, dtype=torch.bool, device=features.device)
        elif self.update_strategy == 'reservoir':
            # Acceptance probability bank_size / rows_seen, capped at 1.
            total_seen = self.num_updates * batch_size
            accept_prob = (self.bank_size / (total_seen + 1e-8)).clamp(max=1.0)
            mask = scores < accept_prob
        elif self.update_strategy in ('entropy', 'diversity', 'hybrid'):
            # Keep the lowest-scoring fraction of the batch. torch.quantile
            # only accepts float/double, so cast (scores may be fp16 under AMP).
            mask = scores < torch.quantile(scores.float(), threshold)
        else:
            # rarity, statistical: absolute cut-off
            mask = scores < threshold

        self._write(features[mask])

    def _write(self, selected):
        """Append ``selected`` rows to the ring buffer starting at ``index``."""
        num_new = selected.size(0)
        if num_new == 0:
            return
        # Writing more rows than slots: only the last bank_size rows would
        # survive sequential writes, so write just those, at the same slots.
        skip = max(num_new - self.bank_size, 0)
        kept = selected[skip:]
        offsets = torch.arange(skip, num_new, device=self.memory.device)
        slots = (self.index.to(self.memory.device) + offsets) % self.bank_size
        # Advanced-index assignment requires matching dtypes (features may be fp16).
        self.memory[slots] = kept.to(self.memory.dtype)
        self.memory_count[slots] += 1
        self.index = (self.index + num_new) % self.bank_size

    def retrieve(self, query, k=3):
        """Return, per query row, a similarity-weighted mix of stored rows.

        For each row of ``query`` the ``k`` most cosine-similar populated slots
        (excluding those with similarity >= 0.9999, e.g. the query itself if it
        was just stored) are combined with softmax(similarity) weights. Rows with
        no eligible slot, or an empty bank, yield zeros. Output matches
        ``query`` in shape and dtype.
        """
        valid_memory = self._valid_memory()
        if valid_memory.size(0) == 0:
            return torch.zeros_like(query)

        similarity = torch.matmul(F.normalize(query, dim=1),
                                  F.normalize(valid_memory, dim=1).T)
        eligible = similarity < 0.9999  # avoid self-similarity

        k = min(k, valid_memory.size(0))
        # Push ineligible entries to the bottom with a finite value (no NaNs in
        # softmax), then zero their weights explicitly.
        masked = similarity.masked_fill(~eligible, torch.finfo(similarity.dtype).min)
        top_sim, top_idx = masked.topk(k, dim=1)
        weights = torch.softmax(top_sim, dim=1) * eligible.gather(1, top_idx)

        retrieved = valid_memory[top_idx]  # (num_queries, k, feature_dim)
        result = (retrieved * weights.unsqueeze(-1)).sum(dim=1)
        return result.to(query.dtype)


# Main Model
class ChestXrayModel(nn.Module):
    """Multi-label classifier: CNN trunk -> pooled features -> memory bank -> MLP head.

    Args:
        num_classes: Number of output logits.
        model_name: 'resnet50', 'densenet121', 'efficientnet_b0' or
            'efficientnet_b1'. NOTE(review): 'efficientnet_b0' actually builds
            torchvision's ``efficientnet_v2_s``. Confirm this is intended.
        dropout_rate: Dropout probability inside the head.
        bank_size, update_strategy, rarity_threshold, diversity_weight,
        memory_momentum: Forwarded to ``MemoryBank``.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """

    def __init__(self, num_classes, model_name='efficientnet_b0', dropout_rate=0.3,
                 bank_size=512, update_strategy='hybrid', rarity_threshold=0.2,
                 diversity_weight=0.5, memory_momentum=0.9):
        super().__init__()

        net, trunk, last_stage, width = self._build_backbone(model_name)
        # Keep registration order base_model -> backbone -> final_block so the
        # parameter ordering is unchanged.
        self.base_model = net
        self.backbone = trunk
        self.final_block = last_stage
        self.feature_dim = width

        # Neutralise torchvision's own heads; a model lacking the attribute
        # simply gets a plain ``None`` attribute.
        for head_name in ('fc', 'classifier'):
            setattr(net, head_name, nn.Identity() if hasattr(net, head_name) else None)

        self.memory_bank = MemoryBank(
            width,
            bank_size=bank_size,
            update_strategy=update_strategy,
            rarity_threshold=rarity_threshold,
            diversity_weight=diversity_weight,
            momentum=memory_momentum,
        )

        self.classifier = nn.Sequential(
            nn.BatchNorm1d(width),
            nn.Linear(width, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )
        self.model_name = model_name

    @staticmethod
    def _build_backbone(model_name):
        """Return ``(torchvision_model, trunk, final_stage, feature_dim)``."""
        if model_name == 'resnet50':
            net = models.resnet50(pretrained=True)
            trunk = nn.Sequential(net.conv1, net.bn1, net.relu, net.maxpool,
                                  net.layer1, net.layer2, net.layer3)
            return net, trunk, net.layer4, 2048

        if model_name == 'densenet121':
            net = models.densenet121(pretrained=True)
            stages = list(net.features.children())
            return net, nn.Sequential(*stages[:-1]), nn.Sequential(stages[-1]), 1024

        if model_name in ('efficientnet_b0', 'efficientnet_b1'):
            if model_name == 'efficientnet_b0':
                net = models.efficientnet_v2_s(pretrained=True)
            else:
                net = models.efficientnet_b1(pretrained=True)
            stages = list(net.features)
            width = net.features[-1][0].out_channels
            return net, nn.Sequential(*stages[:-1]), nn.Sequential(stages[-1]), width

        raise ValueError(f"Model {model_name} not supported")

    def forward(self, x):
        """Return class logits for the image batch ``x``.

        In training mode the pooled features (detached) are first written to the
        memory bank. In every mode the bank is then queried and its output is
        added to the pooled features before the head.
        """
        feature_map = self.final_block(self.backbone(x))
        pooled = F.adaptive_avg_pool2d(feature_map, (1, 1)).flatten(1)

        if self.training:
            self.memory_bank.update(pooled.detach())

        enhanced = pooled + self.memory_bank.retrieve(pooled)
        return self.classifier(enhanced)

# Train Script

In [None]:
from fastai.vision.all import *
import torch.nn as nn
import torch
import torchvision.models as models
from copy import deepcopy

# Focal Loss Implementation
class FocalLoss(nn.Module):
    """Element-wise binary focal loss on logits.

    loss = alpha * (1 - p_t) ** gamma * BCE(logits, targets), with
    p_t = exp(-BCE). ``alpha`` scales every element uniformly; it is not a
    per-class positive/negative balance term.

    Args:
        alpha: Constant multiplier on the loss.
        gamma: Focusing exponent; 0 gives plain BCE (times alpha).
        reduction: 'mean', 'sum', or anything else for no reduction.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against ``targets``."""
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = torch.exp(-bce)  # model's probability for the target label
        focal = self.alpha * (1 - p_t) ** self.gamma * bce

        reducer = {'mean': torch.mean, 'sum': torch.sum}.get(self.reduction)
        return focal if reducer is None else reducer(focal)

# Asymmetric Loss Implementation
class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, reduction='mean'):
        """
        Asymmetric Loss (ASL) for multi-label classification on logits.

        Args:
            gamma_neg (float): Focusing exponent applied to negative targets
            gamma_pos (float): Focusing exponent applied to positive targets
            clip (float): Probability margin for negatives. The negative term uses
                p_m = max(p - clip, 0), so negatives with p <= clip contribute
                zero loss. 0 or None disables the margin.
            eps (float): Lower clamp inside the logs to prevent log(0)
            reduction (str): 'mean', 'sum', or anything else for no reduction
        """
        super(AsymmetricLoss, self).__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, y):
        """
        Compute asymmetric loss

        Args:
            x (torch.Tensor): Model predictions (logits)
            y (torch.Tensor): Ground truth labels (0/1, same shape as x)

        Returns:
            torch.Tensor: Reduced (or element-wise) loss
        """
        p = torch.sigmoid(x)

        # Probability of the negative class, shifted by the margin:
        # p_neg = 1 - max(p - clip, 0) = min(1 - p + clip, 1).
        p_neg = 1 - p
        if self.clip is not None and self.clip > 0:
            p_neg = (p_neg + self.clip).clamp(max=1)

        loss_pos = y * torch.log(p.clamp(min=self.eps)) * torch.pow(1 - p, self.gamma_pos)
        # (1 - p_neg) is the margin-shifted positive probability p_m.
        loss_neg = (1 - y) * torch.log(p_neg.clamp(min=self.eps)) * torch.pow(1 - p_neg, self.gamma_neg)

        loss = -(loss_pos + loss_neg)

        if self.reduction == 'mean':
            return torch.mean(loss)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            return loss

# Create a custom fastai Learner for ChestXrayModel
def create_fastai_learner(
    dls,                            # DataLoaders object
    num_classes=14,                 # Number of output classes (used when `model` is None)
    lr=1e-4,                        # Default learning rate stored on the Learner (learn.lr)
    momentum=0.9,                   # Accepted for API compatibility; not used by ChestXrayModel
    dropout_rate=0.3,               # Classifier dropout (used when `model` is None)
    mixup=False,                    # Whether to add the MixUp callback
    wd=1e-2,                        # Weight decay
    model=None,                     # Pass a pre-instantiated model if you have one
    cbs=None,                       # Extra callback(s); each replaces a default of the same type
    warmup_epochs=0,                # Warm-up epochs for MomentumUpdateCallback
    loss_type='focal',              # Loss type: 'focal', 'asymmetric' or 'bce'
    focal_alpha=1,                  # Focal loss alpha parameter
    focal_gamma=2,                  # Focal loss gamma parameter
    asymmetric_gamma_neg=4,         # Asymmetric loss gamma for negative samples
    asymmetric_gamma_pos=1          # Asymmetric loss gamma for positive samples
):
    """Build a fastai ``Learner`` for multi-label chest X-ray training.

    The model (a ``ChestXrayModel`` with the 'rarity' memory strategy unless
    ``model`` is given) is wrapped in ``ModelWrapper`` and trained in mixed
    precision. Parameters are split into a backbone group (``backbone`` +
    ``final_block``) and a head group (``classifier``), so ``freeze()``,
    ``fine_tune(freeze_epochs=...)`` and ``slice(lo, hi)`` learning rates act
    on the backbone vs. head.

    Returns:
        The configured ``Learner``.

    Raises:
        ValueError: If ``loss_type`` is not 'focal', 'asymmetric' or 'bce'.
    """
    if model is None:
        model = ChestXrayModel(
            num_classes=num_classes,
            dropout_rate=dropout_rate,
            update_strategy='rarity',
        )

    class MomentumUpdateCallback(Callback):
        """After each batch, update ``momentum_final_block`` if the model defines one.

        ChestXrayModel has no such attribute, so for it this callback does nothing.
        """
        def __init__(self, warmup_epochs):
            super().__init__()
            self.warmup_epochs = warmup_epochs

        def after_batch(self):
            # Look through ModelWrapper (applied below) to the wrapped model.
            net = getattr(self.learn.model, 'model', self.learn.model)
            if hasattr(net, 'momentum_final_block'):
                is_warmup = self.learn.epoch < self.warmup_epochs
                net.momentum_final_block.update(net.final_block, warmup=is_warmup)

    class ChestXrayLoss(Module):
        """Multi-label loss on logits, selected by ``loss_type``."""
        def __init__(self, loss_type, **kwargs):
            super().__init__()
            if loss_type == 'focal':
                self.loss = FocalLoss(
                    alpha=kwargs.get('focal_alpha', 1),
                    gamma=kwargs.get('focal_gamma', 2),
                )
            elif loss_type == 'asymmetric':
                self.loss = AsymmetricLoss(
                    gamma_neg=kwargs.get('asymmetric_gamma_neg', 4),
                    gamma_pos=kwargs.get('asymmetric_gamma_pos', 1)
                )
            elif loss_type == 'bce':
                self.loss = nn.BCEWithLogitsLoss()
            else:
                raise ValueError(f"Unsupported loss type: {loss_type}")

        def forward(self, preds, targets):
            return self.loss(preds, targets)

        def activation(self, x):
            # fastai's get_preds/predict apply this to raw outputs, so they
            # return probabilities rather than logits.
            return torch.sigmoid(x)

    def backbone_head_splitter(m):
        """Return [backbone params, head params] so freezing/discriminative LRs work."""
        inner = getattr(m, 'model', m)
        if all(hasattr(inner, name) for name in ('backbone', 'final_block', 'classifier')):
            return [params(inner.backbone) + params(inner.final_block), params(inner.classifier)]
        return trainable_params(m)

    # User callbacks replace defaults of the same type instead of running
    # alongside them (two EarlyStopping callbacks would let the stricter one win).
    if cbs is None:
        user_cbs = []
    elif isinstance(cbs, (list, tuple)):
        user_cbs = list(cbs)
    else:
        user_cbs = [cbs]
    overridden = {type(cb) for cb in user_cbs}
    default_cbs = [
        MomentumUpdateCallback(warmup_epochs),
        SaveModelCallback(monitor='valid_loss'),
        EarlyStoppingCallback(monitor='valid_loss', patience=3),
    ]
    all_cbs = [cb for cb in default_cbs if type(cb) not in overridden] + user_cbs

    learn = Learner(
        dls,
        model,
        loss_func=ChestXrayLoss(
            loss_type=loss_type,
            focal_alpha=focal_alpha,
            focal_gamma=focal_gamma,
            asymmetric_gamma_neg=asymmetric_gamma_neg,
            asymmetric_gamma_pos=asymmetric_gamma_pos
        ),
        metrics=[accuracy_multi, F1ScoreMulti(), RocAucMulti()],  # Multi-label metrics
        lr=lr,
        wd=wd,
        splitter=backbone_head_splitter,
        cbs=all_cbs
    )

    class ModelWrapper(nn.Module):
        """Thin wrapper exposing the model as ``.model``; forward returns logits."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            return self.model(x)

    learn.model = ModelWrapper(learn.model)

    # Mixed precision training
    learn.to_fp16()

    if mixup:
        learn.add_cb(MixUp())

    # Only add a progress bar if the Learner does not already have one.
    if not any(isinstance(cb, ProgressCallback) for cb in learn.cbs):
        learn.add_cb(ProgressCallback())

    # CSV logger to track metrics
    learn.add_cb(CSVLogger())

    return learn

# Train

In [None]:
from fastai.vision.all import *

# Checkpoint on best valid_loss, stop after 5 non-improving epochs, plot losses.
cbs = [
    SaveModelCallback(monitor='valid_loss', min_delta=0.0001, with_opt=True),
    EarlyStoppingCallback(monitor='valid_loss', min_delta=0.001, patience=5),
    ShowGraphCallback(),
]

learn = create_fastai_learner(dls, cbs=cbs, loss_type='bce')

#learn.model = torch.nn.DataParallel(learn.model)

In [None]:
# LR range test; report the 'minimum' suggestion.
suggesters = (minimum, steep, valley, slide)
lrs = learn.lr_find(suggest_funcs=suggesters)
print('intial learning rate=', lrs.minimum)

In [None]:
print('intial learning rate=', lrs.valley, sep=' ')  # 'valley' suggestion, used for fine_tune below

In [None]:
learn.unfreeze()
# Number of trainable parameters (shown as the cell output).
sum(param.numel() for param in learn.model.parameters() if param.requires_grad)

In [None]:
learn.fine_tune(20, base_lr=lrs.valley, freeze_epochs=3)  # 3 frozen epochs, then 20 unfrozen

In [None]:
learn.save(file='fastai_momentum_inner_gate_70_20_10_effnetb1')  # stage-1 weights

In [None]:
#learn = learn.load('/kaggle/working/models/fastai_momentum_cross_spatial_70_20_10')

# Stage 2

In [None]:
import fastai
from fastai.vision.all import *
from tqdm import tqdm
from glob import glob

In [None]:
#learn = learn.load('/kaggle/working/models/fastai_momentum_cross_spatial_70_20_10')
SEED = 85


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    NOTE: PYTHONHASHSEED only affects interpreters started after it is set.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)
# The official split lists contain one image file name per line and no header
# row; without header=None pandas would consume the first file name as a header.
labels_train_val = pd.read_csv('/kaggle/input/data/train_val_list.txt', header=None, names=['Image_Index'])
labels_test = pd.read_csv('/kaggle/input/data/test_list.txt', header=None, names=['Image_Index'])
disease_labels = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening',
'Cardiomegaly', 'Nodule', 'Mass', 'Hernia']
# NIH Dataset Labels CSV File
labels_df = pd.read_csv('/kaggle/input/data/Data_Entry_2017.csv')
labels_df.columns = ['Image_Index', 'Finding_Labels', 'Follow_Up_#', 'Patient_ID',
                  'Patient_Age', 'Patient_Gender', 'View_Position',
                  'Original_Image_Width', 'Original_Image_Height',
                  'Original_Image_Pixel_Spacing_X',
                  'Original_Image_Pixel_Spacing_Y', 'dfd']
# One-hot encode each disease by *exact* match on the '|'-separated findings.
finding_sets = labels_df['Finding_Labels'].astype(str).str.split('|').map(set)
for disease in tqdm(disease_labels):
    labels_df[disease] = finding_sets.map(lambda found, name=disease: int(name in found))

# Stage 2 keeps the 'No Finding' images (stage 1 dropped them).

# "A|B|C" -> ['A', 'B', 'C']
labels_df['Finding_Labels'] = labels_df['Finding_Labels'].apply(lambda s: str(s).split('|'))

# Index every PNG on disk by file name; rows whose image is not found map to None.
num_glob = glob('/kaggle/input/data/*/images/*.png')
img_path = {os.path.basename(x): x for x in num_glob}

labels_df['Paths'] = labels_df['Image_Index'].map(img_path.get)
labels_df.head()

In [None]:
# Sorted distinct patient ids, now including patients with only 'No Finding' images.
unique_patients = np.unique(labels_df['Patient_ID'].to_numpy())
unique_patients.size

In [None]:
from sklearn.model_selection import train_test_split

# Patient-level hold-out: 80% of patients -> train+valid, 20% -> test.
# NOTE(review): unique_patients here also contains 'No Finding'-only patients,
# while stage 1 split a smaller patient set. With the same seed the two
# partitions therefore differ, so stage-2 valid/test patients may have been
# seen in stage-1 training. Confirm this is acceptable.
train_val_df_patients, test_df_patients = train_test_split(
    unique_patients,
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
)
len(train_val_df_patients)

In [None]:
train_val_df = labels_df.loc[labels_df['Patient_ID'].isin(train_val_df_patients)]  # rows of train+valid patients

In [None]:
# Image counts per split (test = every image not belonging to a train+valid patient).
n_train_val = len(train_val_df)
print('train_val size', n_train_val)
print('test size', len(labels_df) - n_train_val)

In [None]:
item_transforms = [
    Resize((224, 224)),
]

batch_transforms = [
    Flip(),
    Rotate(),
    Normalize.from_stats(*imagenet_stats),
]


def get_x(row):
    """Return the image path stored in the row's 'Paths' column."""
    return row['Paths']

def get_y(row):
    """Return the row's 0/1 flags for every name in `disease_labels`, in that order."""
    return row[disease_labels].tolist()

def patient_splitter(df, valid_pct=0.125, seed=SEED):
    """Split the row positions of `df` into (train, valid) *by patient*.

    Splitting by image would let one patient's images appear in both train and
    valid, inflating the monitored valid_loss. Here a seeded random
    `valid_pct` fraction of the patients goes to validation.

    Returns:
        (train_positions, valid_positions): lists of integer row positions.
    """
    patients = np.unique(df['Patient_ID'].to_numpy())
    rng = np.random.RandomState(seed)
    n_valid = int(round(valid_pct * len(patients)))
    valid_patients = rng.permutation(patients)[:n_valid]
    in_valid = df['Patient_ID'].isin(valid_patients).to_numpy()
    return np.flatnonzero(~in_valid).tolist(), np.flatnonzero(in_valid).tolist()

dblock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock(encoded=True, vocab=disease_labels)),
    splitter=patient_splitter,
    get_x=get_x,
    get_y=get_y,
    item_tfms=item_transforms,
    batch_tfms=batch_transforms,
)
dls = dblock.dataloaders(train_val_df, bs=128)
# print(dblock.datasets(train_val_merge).train)

In [None]:
from fastai.vision.all import *

# Checkpoint on best valid_loss, stop after 5 non-improving epochs, plot losses.
cbs = [
    SaveModelCallback(monitor='valid_loss', min_delta=0.0001, with_opt=True),
    EarlyStoppingCallback(monitor='valid_loss', min_delta=0.001, patience=5),
    ShowGraphCallback(),
]

# Rebuild the learner and restore the stage-1 weights (presumably the file
# written by SaveModelCallback during stage 1; confirm).
learn = create_fastai_learner(dls, cbs=cbs, momentum=0.9999, loss_type='bce')
learn = learn.load('/kaggle/working/models/model')
#learn.model = torch.nn.DataParallel(learn.model)

In [None]:
learn.unfreeze()
learn.fit_one_cycle(n_epoch=5, lr_max=slice(2e-5, 8e-5))  # low LRs for stage-2 refinement

In [None]:
learn.save(file='fastai_momentum_cross_spatial_70_20_10_stage2')  # stage-2 weights

# Evaluation

In [None]:
from sklearn.metrics import roc_auc_score
def get_roc_auc(learner):
    """Print and return macro and per-class ROC AUC on the learner's validation set.

    Predictions come from ``learner.get_preds(ds_idx=1)``, i.e. the validation
    split of the learner's DataLoaders. NOTE(review): this is not the held-out
    ``test_df_patients`` split; confirm that is the intended evaluation set.
    Whether ``preds`` are logits or probabilities depends on the learner's loss
    function. ROC AUC is unaffected either way.

    Returns:
        dict with 'roc_auc' (macro average), 'preds' and 'y_test' tensors.
    """
    # No learner.freeze(): inference does not need it, and it would change which
    # parameters the learner trains afterwards.
    preds, y_test = learner.get_preds(ds_idx=1)
    y_true = y_test.cpu().numpy()
    y_score = preds.cpu().numpy()

    roc_auc = roc_auc_score(y_true, y_score)
    scores = [roc_auc_score(y_true[:, i], y_score[:, i]) for i in range(y_true.shape[1])]

    print('ROC_AUC_Labels:', list(zip(disease_labels, scores)))
    print(f'SCORE: {roc_auc}')
    return {
        'roc_auc': roc_auc,
        'preds': preds,
        'y_test': y_test
    }

In [None]:
# Evaluate and persist the raw predictions for later comparison.
modelv1_result = get_roc_auc(learn)
preds = modelv1_result['preds']
torch.save(obj=preds, f='modelv1_result.pt')

In [None]:
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, precision_score, recall_score, f1_score

def get_roc_auc(learner, threshold=0.5):
    """Evaluate ``learner`` on its validation split with ROC AUC and thresholded metrics.

    The model outputs logits, so a sigmoid is applied explicitly before
    thresholding. ``threshold`` is therefore a probability cut-off.
    NOTE(review): ds_idx=1 is the DataLoaders' validation split, not the
    held-out ``test_df_patients``.

    Args:
        learner: fastai Learner whose model returns multi-label logits.
        threshold: Probability cut-off for precision/recall/F1 (ROC AUC ignores it).

    Returns:
        dict with macro 'roc_auc', 'precision', 'recall', 'f1'; per-class lists
        under 'class_scores'; and the 'preds' (probabilities), 'y_test' and
        'binary_preds' tensors.
    """
    # No learner.freeze(): inference does not need it, and it would change which
    # parameters the learner trains afterwards.
    preds, y_test = learner.get_preds(ds_idx=1, act=torch.sigmoid)
    binary_preds = (preds > threshold).float()

    # Convert to NumPy once for all sklearn calls.
    y_true = y_test.cpu().numpy()
    y_score = preds.cpu().numpy()
    y_pred = binary_preds.cpu().numpy()

    # ROC AUC is threshold-free (it sweeps all thresholds).
    roc_auc = roc_auc_score(y_true, y_score)
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    class_ids = range(y_true.shape[1])
    scores = [roc_auc_score(y_true[:, i], y_score[:, i]) for i in class_ids]
    precision_scores = [precision_score(y_true[:, i], y_pred[:, i], zero_division=0) for i in class_ids]
    recall_scores = [recall_score(y_true[:, i], y_pred[:, i], zero_division=0) for i in class_ids]
    f1_scores = [f1_score(y_true[:, i], y_pred[:, i], zero_division=0) for i in class_ids]

    print(f'Using threshold: {threshold}')
    print('ROC_AUC_Labels:', list(zip(disease_labels, scores)))
    print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}')
    print(f'ROC AUC Score: {roc_auc:.4f}')

    return {
        'roc_auc': roc_auc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'class_scores': {
            'roc_auc': scores,
            'precision': precision_scores,
            'recall': recall_scores,
            'f1': f1_scores
        },
        'preds': preds,
        'y_test': y_test,
        'binary_preds': binary_preds
    }
