# Install

# Data

In [None]:
import fastai
from fastai.vision.all import *
from tqdm import tqdm
from glob import glob

In [None]:
SEED = 85


def seed_everything(seed):
    """Seed every random-number source the notebook relies on.

    Touches Python's ``random``, NumPy's global generator, PyTorch's CPU
    generator and ``torch.cuda.manual_seed``, exports ``PYTHONHASHSEED``,
    and puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    # NOTE(review): setting PYTHONHASHSEED here only affects child processes,
    # not hashing in the already-running interpreter.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seeder(seed)
    # Reproducible kernels at the cost of cuDNN autotuning speed.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


seed_everything(SEED)

In [None]:
# Global run settings.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
IMAGE_SIZE = [224, 224]  # NOTE(review): the DataBlocks below hard-code Resize((224, 224)) instead of reading this.
BATCH_SIZE = 128  # NOTE(review): phase-1 dataloaders use bs=64; phase 2 passes 128 literally.
EPOCHS = 10  # NOTE(review): not referenced by the visible training code.

In [None]:
# Official NIH image lists. NOTE(review): read but not used below — this
# notebook splits by patient with train_test_split instead.
labels_train_val = pd.read_csv('/kaggle/input/data/train_val_list.txt')
labels_train_val.columns = ['Image_Index']
labels_test = pd.read_csv('/kaggle/input/data/test_list.txt')
labels_test.columns = ['Image_Index']
# The 14 findings; this order defines the target vector (get_y / vocab).
disease_labels = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening',
'Cardiomegaly', 'Nodule', 'Mass', 'Hernia']
# NIH Dataset Labels CSV File (one row per image).
labels_df = pd.read_csv('/kaggle/input/data/Data_Entry_2017.csv')
# Identifier-friendly column names; 'dfd' is a placeholder for the CSV's
# 12th column (presumably an unnamed trailing column — verify).
labels_df.columns = ['Image_Index', 'Finding_Labels', 'Follow_Up_#', 'Patient_ID',
                  'Patient_Age', 'Patient_Gender', 'View_Position',
                  'Original_Image_Width', 'Original_Image_Height',
                  'Original_Image_Pixel_Spacing_X',
                  'Original_Image_Pixel_Spacing_Y', 'dfd']
# One hot encoding: 1 if the disease name occurs in the '|'-separated finding
# string. A substring test is safe because none of the 14 names is a
# substring of another.
for diseases in tqdm(disease_labels):
    labels_df[diseases] = labels_df['Finding_Labels'].map(lambda result: 1 if diseases in result else 0)

# Phase 1 trains only on images with at least one finding. `.copy()` makes
# this an independent frame, so the column assignments in the next cell
# ('Finding_Labels', 'Paths') do not write into a slice of the unfiltered
# frame (pandas SettingWithCopyWarning / chained assignment).
labels_df = labels_df[labels_df.Finding_Labels != 'No Finding'].copy()

In [None]:

# Split the pipe-separated finding string into a list of label names.
labels_df['Finding_Labels'] = labels_df['Finding_Labels'].apply(lambda s: str(s).split('|'))

# Map every PNG file name to its full path; images with no file get None.
num_glob = glob('/kaggle/input/data/*/images/*.png')
img_path = dict((os.path.basename(path), path) for path in num_glob)

labels_df['Paths'] = labels_df['Image_Index'].map(img_path.get)
labels_df.head()

In [None]:
# Sorted array of distinct patient IDs; the split below is done per patient.
unique_patients = np.unique(labels_df.Patient_ID)
len(unique_patients)

In [None]:
from sklearn.model_selection import train_test_split

# Patient-level split: 80% of patients -> train/val, 20% -> test.
# The DataBlock's RandomSplitter(valid_pct=0.125) later carves 10% of the
# total out of train/val for validation (70 / 10 / 20 overall).
train_val_df_patients, test_df_patients = train_test_split(
    unique_patients,
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
)
len(train_val_df_patients)

In [None]:
# Images whose patient landed in the train/val partition.
train_val_df = labels_df.loc[labels_df.Patient_ID.isin(train_val_df_patients)]

In [None]:
# Quick look at the first rows of the train/val frame.
train_val_df.head(n=5)

In [None]:
# Split sizes; the test partition is everything outside train/val.
print('train_val size', len(train_val_df))
print('test size', len(labels_df) - len(train_val_df))

# Data builder

In [None]:
# Per-item resize (before batching).
item_transforms = [Resize((224, 224))]

# Batch-level augmentation, then ImageNet normalisation for the pretrained backbones.
batch_transforms = [Flip(), Rotate(), Normalize.from_stats(*imagenet_stats)]


def get_x(row):
    """Return the image path stored in a labels_df row."""
    return row['Paths']


def get_y(row):
    """Return the multi-hot target of a labels_df row, ordered like disease_labels."""
    return row[disease_labels].tolist()


# NOTE(review): RandomSplitter splits by image, not by patient, so one
# patient's images can appear in both train and valid — confirm acceptable.
dblock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock(encoded=True, vocab=disease_labels)),
    get_x=get_x,
    get_y=get_y,
    splitter=RandomSplitter(valid_pct=0.125, seed=SEED),
    item_tfms=item_transforms,
    batch_tfms=batch_transforms,
)
dls = dblock.dataloaders(train_val_df, bs=64)

# Model

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn import functional as F
from copy import deepcopy

# Configuration block for all parameters
class ModelConfig:
    """Hyper-parameter namespace shared by the model components.

    Used through its class attributes (components read ``ModelConfig.X`` as
    a default, and the grid search assigns new values to them).
    """

    # Momentum (EMA) encoder decay.
    MOMENTUM = 0.9999

    # Channel reduction factor of each SpatialAttention branch.
    ATTENTION_REDUCTION = 8

    # MemoryBank: number of slots, score threshold for storing, neighbours retrieved.
    BANK_SIZE = 512
    RARITY_THRESHOLD = 0.2
    RETRIEVAL_K = 3

    # Classifier head.
    DROPOUT_RATE = 0.3
    HIDDEN_DIM = 512

    @staticmethod
    def get_feature_dim(model_name):
        """Return the final-stage channel width of a supported backbone.

        EfficientNet names return ``None``: their width is read from the
        instantiated network at runtime.

        Raises:
            ValueError: if ``model_name`` is not a supported backbone.
        """
        known_widths = (
            ('resnet50', 2048),
            ('densenet121', 1024),
            ('efficientnet_b0', None),
            ('efficientnet_b1', None),
        )
        for candidate, width in known_widths:
            if model_name == candidate:
                return width
        raise ValueError(f"Model {model_name} not supported")

# Momentum Encoder: Simplified to final block copy
class MomentumFinalBlock(nn.Module):
    """Frozen exponential-moving-average (EMA) copy of a network block.

    The copy is made with ``deepcopy`` and its parameters never receive
    gradients. ``update`` pulls its weights toward the live block.
    Buffers (e.g. BatchNorm running stats) are not copied by ``update``.

    Args:
        final_block: Module to shadow.
        momentum: EMA decay; ``None`` falls back to ``ModelConfig.MOMENTUM``.
            Higher values make the copy move more slowly.
    """

    def __init__(self, final_block, momentum=None):
        super(MomentumFinalBlock, self).__init__()
        self.momentum = momentum if momentum is not None else ModelConfig.MOMENTUM
        self.final_block = deepcopy(final_block)
        for param in self.final_block.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run the EMA copy on ``x``."""
        return self.final_block(x)

    @torch.no_grad()
    def update(self, main_final_block, warmup=False):
        """Blend the live block's parameters into the EMA copy, in place.

        Computes ``ema = ema * momentum + live * (1 - momentum)`` for every
        parameter pair (matched by ``parameters()`` order).

        Args:
            main_final_block: Live module with the same parameter layout as
                the module this copy was made from.
            warmup: If True, copy the live parameters verbatim instead of
                blending. Accepted because ``MomentumUpdateCallback`` calls
                ``update(..., warmup=...)``.
        """
        for param_q, param_k in zip(main_final_block.parameters(), self.final_block.parameters()):
            if warmup:
                param_k.copy_(param_q)
            else:
                param_k.mul_(self.momentum).add_(param_q, alpha=1.0 - self.momentum)

# Spatial Attention: Lightweight ROI selection
class SpatialAttention(nn.Module):
    """Multi-scale spatial attention map.

    Three parallel convolutions (1x1, 3x3, 5x5, all size-preserving) reduce
    the channel count. Their outputs are concatenated, and a 1x1 conv plus
    sigmoid collapses them to one weight per spatial position.

    Args:
        in_channels: Channels of the incoming feature map.
        reduction: Per-branch channel reduction factor; ``None`` uses
            ``ModelConfig.ATTENTION_REDUCTION``. Each branch keeps at least
            8 channels.
    """

    def __init__(self, in_channels, reduction=None):
        super(SpatialAttention, self).__init__()
        if reduction is None:
            reduction = ModelConfig.ATTENTION_REDUCTION
        branch_width = max(in_channels // reduction, 8)

        # Created in this order so parameter initialisation and state_dict
        # keys stay the same.
        self.conv1 = nn.Conv2d(in_channels, branch_width, kernel_size=1)
        self.conv3 = nn.Conv2d(in_channels, branch_width, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels, branch_width, kernel_size=5, padding=2)

        self.spatial_att = nn.Sequential(
            nn.Conv2d(3 * branch_width, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a one-channel sigmoid attention map for a ``[batch, C, H, W]`` input."""
        branches = (self.conv1, self.conv3, self.conv5)
        multi_scale = torch.cat([branch(x) for branch in branches], dim=1)
        return self.spatial_att(multi_scale)

# Memory Bank: Store rare/important features
class MemoryBank(nn.Module):
    """Fixed-size ring buffer of feature vectors with cosine-similarity lookup.

    ``update`` stores the rows whose ``rarity_scores`` fall below
    ``rarity_threshold``. ``retrieve`` returns, for each query, the
    similarity-weighted sum of its top-k most similar stored rows.

    Buffers (part of the state_dict):
        memory: ``[bank_size, feature_dim]`` stored rows; never-written slots
            stay all-zero.
        index: next write position in the ring.

    Args:
        feature_dim: Width of each stored vector.
        bank_size: Number of slots; ``None`` uses ``ModelConfig.BANK_SIZE``.
        rarity_threshold: Rows with ``score < rarity_threshold`` are stored;
            ``None`` uses ``ModelConfig.RARITY_THRESHOLD``.
    """

    def __init__(self, feature_dim, bank_size=None, rarity_threshold=None):
        super(MemoryBank, self).__init__()
        self.feature_dim = feature_dim
        self.bank_size = bank_size if bank_size is not None else ModelConfig.BANK_SIZE
        self.rarity_threshold = rarity_threshold if rarity_threshold is not None else ModelConfig.RARITY_THRESHOLD

        self.register_buffer('memory', torch.zeros(self.bank_size, feature_dim))
        self.register_buffer('index', torch.tensor(0))

    @torch.no_grad()
    def update(self, features, rarity_scores):
        """Write the selected rows of ``features`` into the ring buffer.

        Writing starts at ``index`` and wraps to slot 0 at the end of the
        bank, so rows are not dropped near the end. At most ``bank_size``
        rows (the first ones selected) are written per call.

        Args:
            features: ``[batch, feature_dim]`` features (ChestXrayModel
                passes them detached).
            rarity_scores: ``[batch]`` scores compared with ``rarity_threshold``.
        """
        # NOTE(review): this keeps rows with a *low* score. With the score
        # computed in ChestXrayModel.forward (relative deviation of the norm
        # from the batch mean), that selects the most typical rows rather
        # than rare ones — confirm the intended direction.
        selected = features[rarity_scores < self.rarity_threshold]
        count = min(selected.size(0), self.bank_size)
        if count == 0:
            return
        start = int(self.index.item())
        slots = (start + torch.arange(count, device=self.memory.device)) % self.bank_size
        self.memory[slots] = selected[:count].to(self.memory.dtype)
        self.index.fill_((start + count) % self.bank_size)

    def retrieve(self, query, k=None):
        """Return the similarity-weighted sum of each query's top-``k`` memories.

        Cosine similarity is computed against every written slot. Exact
        self-matches (similarity == 1.0) and never-written (all-zero) slots
        are ignored. Rows with no eligible slot get zeros; rows with fewer
        than ``k`` eligible slots use just those.

        Args:
            query: ``[batch, feature_dim]`` query features.
            k: Number of neighbours; ``None`` uses ``ModelConfig.RETRIEVAL_K``.

        Returns:
            Tensor with the same shape and dtype as ``query``.
        """
        k = k if k is not None else ModelConfig.RETRIEVAL_K
        memory = self.memory
        # Never-written slots are still exactly zero and are not memories.
        written = memory.ne(0).any(dim=1)
        k = min(k, memory.size(0))
        if k <= 0 or not bool(written.any()):
            return torch.zeros_like(query)

        norm_query = F.normalize(query, dim=1)
        norm_memory = F.normalize(memory, dim=1)
        similarity = torch.matmul(norm_query, norm_memory.T)  # [batch, bank_size]

        # NOTE(review): the exact float comparison only drops a self-match
        # when rounding yields exactly 1.0.
        eligible = written.unsqueeze(0) & (similarity != 1.0)
        scores = similarity.masked_fill(~eligible, float('-inf'))
        top_scores, top_idx = scores.topk(k, dim=1)

        # Picks that fell on ineligible slots (fewer than k eligible) weigh 0.
        picked_eligible = torch.gather(eligible, 1, top_idx)
        weights = torch.where(picked_eligible, top_scores, torch.zeros_like(top_scores))
        neighbours = memory[top_idx]  # [batch, k, feature_dim]
        combined = (neighbours * weights.unsqueeze(-1)).sum(dim=1)
        return combined.to(query.dtype)

# Main Model
class ChestXrayModel(nn.Module):
    """Multi-label chest X-ray classifier on a pretrained torchvision backbone.

    The torchvision model is split into ``backbone`` (all but the last stage)
    and ``final_block`` (the last stage). ``forward`` adds two things:
    spatial-attention-weighted, average-pooled ``final_block`` features, and
    pooled features from a frozen EMA copy of ``final_block``. It then adds
    features retrieved from a ``MemoryBank`` and classifies the result with
    a small MLP.

    Args:
        num_classes: Number of output logits.
        model_name: 'resnet50', 'densenet121', 'efficientnet_b0' or
            'efficientnet_b1'; anything else raises ValueError.
        config: Object providing the ``ModelConfig`` attributes (MOMENTUM,
            ATTENTION_REDUCTION, BANK_SIZE, RARITY_THRESHOLD, RETRIEVAL_K,
            HIDDEN_DIM, DROPOUT_RATE); defaults to the ``ModelConfig`` class.
    """
    def __init__(self, num_classes, model_name='efficientnet_b0', config=None):
        super(ChestXrayModel, self).__init__()
        
        # Use provided config or the default ModelConfig
        self.config = config if config is not None else ModelConfig
        
        # Backbone and Final Block
        # `final_block` is the last stage; it is the part mirrored by the
        # momentum (EMA) encoder below.
        if model_name == 'resnet50':
            self.base_model = models.resnet50(pretrained=True)
            self.backbone = nn.Sequential(
                self.base_model.conv1, self.base_model.bn1, self.base_model.relu,
                self.base_model.maxpool, self.base_model.layer1, self.base_model.layer2,
                self.base_model.layer3
            )
            self.final_block = self.base_model.layer4
            self.feature_dim = 2048
        elif model_name == 'densenet121':
            self.base_model = models.densenet121(pretrained=True)
            features = list(self.base_model.features.children())
            # NOTE(review): the last child of `features` is presumably the
            # final norm layer; torchvision applies a ReLU after it that is
            # not applied here — confirm this is intended.
            self.backbone = nn.Sequential(*features[:-1])
            self.final_block = nn.Sequential(features[-1])
            self.feature_dim = 1024
        elif model_name in ['efficientnet_b0', 'efficientnet_b1']:
            # NOTE(review): 'efficientnet_b0' builds efficientnet_v2_s, not
            # efficientnet_b0. The name is misleading; confirm the intended
            # architecture before changing it (existing checkpoints depend on it).
            self.base_model = models.efficientnet_v2_s(pretrained=True) if model_name == 'efficientnet_b0' else models.efficientnet_b1(pretrained=True)
            features = list(self.base_model.features)
            self.backbone = nn.Sequential(*features[:-1])
            self.final_block = nn.Sequential(features[-1])
            # Channel width read from the first submodule of the last stage
            # (assumes it is a conv with `out_channels` — TODO confirm).
            self.feature_dim = self.base_model.features[-1][0].out_channels
        else:
            raise ValueError(f"Model {model_name} not supported")

        # Neutralise torchvision's classification head. When the attribute is
        # absent it is set to None instead; base_model's own forward is never
        # called by this class.
        self.base_model.fc = nn.Identity() if hasattr(self.base_model, 'fc') else None
        self.base_model.classifier = nn.Identity() if hasattr(self.base_model, 'classifier') else None

        # Momentum Encoder
        self.momentum_final_block = MomentumFinalBlock(self.final_block, momentum=self.config.MOMENTUM)

        # Spatial Attention
        self.spatial_attention = SpatialAttention(self.feature_dim, reduction=self.config.ATTENTION_REDUCTION)

        # Memory Bank
        self.memory_bank = MemoryBank(
            self.feature_dim, 
            bank_size=self.config.BANK_SIZE, 
            rarity_threshold=self.config.RARITY_THRESHOLD
        )

        # Classifier
        # BatchNorm1d on the pooled vector -> hidden layer -> dropout -> logits.
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(self.feature_dim),
            nn.Linear(self.feature_dim, self.config.HIDDEN_DIM),
            nn.ReLU(),
            nn.Dropout(self.config.DROPOUT_RATE),
            nn.Linear(self.config.HIDDEN_DIM, num_classes)
        )
        self.model_name = model_name

    def forward(self, x):
        """Return class logits of shape ``[batch, num_classes]`` for images ``x``.

        Side effects when ``self.training`` is True: low-score rows of the
        fused features are written to the memory bank, and one EMA step is
        applied to ``momentum_final_block``. Memory retrieval runs in both
        training and eval mode.
        """
        # Extract features
        backbone_features = self.backbone(x)
        main_features = self.final_block(backbone_features)
        # The EMA branch is frozen and gets no gradient.
        with torch.no_grad():
            momentum_features = self.momentum_final_block(backbone_features)

        # Spatial attention and ROI extraction
        # The 1-channel attention map broadcasts over all feature channels.
        attention_map = self.spatial_attention(main_features)
        roi_features = main_features * attention_map
        roi_pooled = F.adaptive_avg_pool2d(roi_features, (1, 1)).flatten(1)

        # Momentum features
        momentum_pooled = F.adaptive_avg_pool2d(momentum_features, (1, 1)).flatten(1)

        # Combine ROI and momentum features (simple addition)
        fused_features = roi_pooled + momentum_pooled

        # Update and retrieve from memory bank
        # Score = relative deviation of each row's L2 norm from the batch mean.
        # NOTE(review): no epsilon guards a zero mean_norm.
        # The bank is written *before* retrieval, so this batch's own
        # (detached) rows are candidates; MemoryBank.retrieve only skips
        # exact similarity == 1.0 matches.
        if self.training:
            mean_norm = torch.mean(torch.norm(fused_features, dim=1))
            rarity_scores = torch.abs(torch.norm(fused_features, dim=1) - mean_norm) / mean_norm
            self.memory_bank.update(fused_features.detach(), rarity_scores)
        
        memory_features = self.memory_bank.retrieve(fused_features, k=self.config.RETRIEVAL_K)
        enhanced_features = fused_features + memory_features

        # Classification
        out = self.classifier(enhanced_features)

        # Update momentum encoder during training
        # Runs on every training-mode forward call, i.e. before the optimizer
        # step for this batch.
        if self.training:
             self.momentum_final_block.update(self.final_block)

        return out



# Train Script

In [None]:
from fastai.vision.all import *
import torch.nn as nn
import torch
import torchvision.models as models
from copy import deepcopy

# Focal Loss Implementation
class FocalLoss(nn.Module):
    """Binary focal loss on logits.

    Each element's BCE-with-logits term is scaled by
    ``alpha * (1 - p_t) ** gamma`` with ``p_t = exp(-BCE)``. For hard 0/1
    targets, ``p_t`` is the probability the model assigns to the true label.
    ``alpha`` is one global multiplier applied to positives and negatives
    alike; it is not a per-class or positive-only weight.

    Args:
        alpha: Global scale of the loss.
        gamma: Focusing exponent; ``0`` gives plain BCE-with-logits.
        reduction: 'mean' or 'sum'; any other value returns the unreduced
            per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against same-shaped ``targets``."""
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) recovers p_t; well-classified elements are down-weighted.
        modulating = (1 - torch.exp(-bce)) ** self.gamma
        per_element = self.alpha * modulating * bce
        if self.reduction == 'mean':
            return per_element.mean()
        if self.reduction == 'sum':
            return per_element.sum()
        return per_element

# Asymmetric Loss Implementation
class AsymmetricLoss(nn.Module):
    """Asymmetric loss (ASL) for multi-label classification on logits.

    Positive and negative targets get separate focusing exponents. Negative
    probabilities are shifted down by ``clip`` (probability margin), so easy
    negatives with ``p <= clip`` contribute zero loss.

    Args:
        gamma_neg: Focusing exponent for negative targets.
        gamma_pos: Focusing exponent for positive targets.
        clip: Probability margin subtracted from negative-sample
            probabilities; ``None`` or ``0`` disables the shift.
        eps: Lower bound applied to each ``log`` argument to avoid ``log(0)``.
        reduction: 'mean' or 'sum'; any other value returns the unreduced loss.
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, reduction='mean'):
        super(AsymmetricLoss, self).__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, y):
        """Return the asymmetric loss of logits ``x`` against same-shaped targets ``y``."""
        prob = torch.sigmoid(x)

        # Probability margin for negatives (was documented but never applied).
        if self.clip is not None and self.clip > 0:
            prob_neg = (prob - self.clip).clamp(min=0)
        else:
            prob_neg = prob

        # Clamp the argument of each log rather than the probability. In
        # float32, 1 - 1e-8 rounds to 1.0, so clamping p to max=1-eps did not
        # stop log(1 - p) = -inf, and (1 - y) * -inf gave NaN for confident
        # positives.
        log_pos = torch.log(prob.clamp(min=self.eps))
        log_neg = torch.log((1 - prob_neg).clamp(min=self.eps))

        loss_pos = -y * log_pos * torch.pow(1 - prob, self.gamma_pos)
        loss_neg = -(1 - y) * log_neg * torch.pow(prob_neg, self.gamma_neg)
        loss = loss_pos + loss_neg

        if self.reduction == 'mean':
            return torch.mean(loss)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            return loss

# Create a custom fastai Learner for ChestXrayModel
def create_fastai_learner(
    dls,                            # DataLoaders object
    num_classes=14,                 # Number of output classes
    lr=1e-4,                        # Default learning rate stored on the Learner
    momentum=0.9,                   # Unused (EMA momentum comes from ModelConfig.MOMENTUM)
    dropout_rate=0.3,               # Unused (dropout comes from ModelConfig.DROPOUT_RATE)
    mixup=False,                    # Whether to use mixup augmentation
    wd=1e-2,                        # Weight decay
    model=None,                     # Pass a pre-instantiated model if you have one
    cbs=None,                       # Additional callbacks
    warmup_epochs=0,                # Epochs during which the EMA block is hard-synced
    loss_type='focal',              # Loss type: 'focal', 'asymmetric' or 'bce'
    focal_alpha=1,                  # Focal loss alpha parameter
    focal_gamma=2,                  # Focal loss gamma parameter
    asymmetric_gamma_neg=4,         # Asymmetric loss gamma for negative samples
    asymmetric_gamma_pos=1          # Asymmetric loss gamma for positive samples
):
    """Build a fastai ``Learner`` around ``ChestXrayModel``.

    The model is wrapped in ``ModelWrapper`` after the Learner is created,
    so saved checkpoints carry a ``model.`` key prefix. Mixed precision
    (``to_fp16``) and a CSV logger are always enabled.

    Args:
        dls: fastai DataLoaders.
        num_classes: Output logits when ``model`` is None.
        lr: Learner default learning rate (used by fit calls that pass none).
        momentum, dropout_rate: Accepted for compatibility but unused; the
            model reads these from ``ModelConfig`` (the grid search sets them there).
        mixup: Add fastai's MixUp callback.
        wd: Weight decay.
        model: Pre-built model; a new ``ChestXrayModel`` otherwise.
        cbs: Extra callback or list of callbacks.
        warmup_epochs: For the first ``warmup_epochs`` epochs, copy the live
            final block into the EMA block after every training batch.
        loss_type: 'focal', 'asymmetric' or 'bce'.
        focal_alpha, focal_gamma: FocalLoss parameters.
        asymmetric_gamma_neg, asymmetric_gamma_pos: AsymmetricLoss parameters.

    Returns:
        The configured fastai Learner.

    Raises:
        ValueError: for an unsupported ``loss_type``.
    """
    # Create model if not provided
    if model is None:
        model = ChestXrayModel(
            num_classes=num_classes
        )

    class MomentumUpdateCallback(Callback):
        """Hard-sync the EMA block to the live block during warm-up epochs.

        ChestXrayModel.forward already applies the regular EMA step on every
        training forward pass, so outside the warm-up window this does
        nothing (avoiding a second EMA step per batch).
        """
        def __init__(self, warmup_epochs):
            super().__init__()
            self.warmup_epochs = warmup_epochs

        def after_batch(self):
            if not self.learn.training or self.learn.epoch >= self.warmup_epochs:
                return
            # learn.model is a ModelWrapper; look through it to the real model.
            net = getattr(self.learn.model, 'model', self.learn.model)
            ema_block = getattr(net, 'momentum_final_block', None)
            if ema_block is None:
                return
            with torch.no_grad():
                for p_live, p_ema in zip(net.final_block.parameters(), ema_block.final_block.parameters()):
                    p_ema.copy_(p_live)

    class ChestXrayLoss(Module):
        """Select the training loss by name and forward (preds, targets) to it."""
        def __init__(self, loss_type, **kwargs):
            super().__init__()
            if loss_type == 'focal':
                # These were previously commented out, so focal_alpha /
                # focal_gamma were silently ignored.
                self.loss = FocalLoss(
                    alpha=kwargs.get('focal_alpha', 1),
                    gamma=kwargs.get('focal_gamma', 2)
                )
            elif loss_type == 'asymmetric':
                self.loss = AsymmetricLoss(
                    gamma_neg=kwargs.get('asymmetric_gamma_neg', 4),
                    gamma_pos=kwargs.get('asymmetric_gamma_pos', 1)
                )
            elif loss_type == 'bce':
                self.loss = nn.BCEWithLogitsLoss()
            else:
                raise ValueError(f"Unsupported loss type: {loss_type}")

        def forward(self, preds, targets):
            return self.loss(preds, targets)

    # Prepare default callbacks
    default_cbs = [
        MomentumUpdateCallback(warmup_epochs),  # EMA hard-sync during warm-up
        SaveModelCallback(monitor='valid_loss'),  # Save best model
        EarlyStoppingCallback(monitor='valid_loss', patience=3)  # Early stopping
    ]

    # Add user-specified callbacks
    if cbs is not None:
        if isinstance(cbs, list):
            default_cbs.extend(cbs)
        else:
            default_cbs.append(cbs)

    # Create the learner with custom model and loss
    learn = Learner(
        dls,
        model,
        loss_func=ChestXrayLoss(
            loss_type=loss_type,
            focal_alpha=focal_alpha,
            focal_gamma=focal_gamma,
            asymmetric_gamma_neg=asymmetric_gamma_neg,
            asymmetric_gamma_pos=asymmetric_gamma_pos
        ),
        lr=lr,
        metrics=[accuracy_multi, F1ScoreMulti(), RocAucMulti()],  # Multi-label metrics
        wd=wd,
        cbs=default_cbs
    )

    class ModelWrapper(nn.Module):
        """Thin wrapper; kept so checkpoint keys keep their ``model.`` prefix."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            return self.model(x)  # Model returns only class predictions

    # Wrap the model
    learn.model = ModelWrapper(learn.model)

    # Enable mixed precision training
    learn.to_fp16()

    # Add mixup if requested
    if mixup:
        learn.add_cb(MixUp())

    # fastai normally installs a ProgressCallback by default; only add one if
    # it is missing, to avoid duplicate progress bars.
    if not any(isinstance(cb, ProgressCallback) for cb in learn.cbs):
        learn.add_cb(ProgressCallback())

    # Add CSV logger to track metrics
    learn.add_cb(CSVLogger())

    return learn

# Stage 2

In [None]:
import fastai
from fastai.vision.all import *
from tqdm import tqdm
from glob import glob

In [None]:
#learn = learn.load('/kaggle/working/models/fastai_momentum_cross_spatial_70_20_10')
SEED = 85


def seed_everything(seed):
    """Seed every random-number source the notebook relies on.

    Touches Python's ``random``, NumPy's global generator, PyTorch's CPU
    generator and ``torch.cuda.manual_seed``, exports ``PYTHONHASHSEED``,
    and puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    # NOTE(review): PYTHONHASHSEED set at runtime only affects child processes.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seeder(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


seed_everything(SEED)
# Official NIH image lists (read for reference; the split below is per patient).
labels_train_val = pd.read_csv('/kaggle/input/data/train_val_list.txt')
labels_train_val.columns = ['Image_Index']
labels_test = pd.read_csv('/kaggle/input/data/test_list.txt')
labels_test.columns = ['Image_Index']
disease_labels = [
    'Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema',
    'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening',
    'Cardiomegaly', 'Nodule', 'Mass', 'Hernia',
]
# NIH metadata CSV (one row per image).
labels_df = pd.read_csv('/kaggle/input/data/Data_Entry_2017.csv')
labels_df.columns = [
    'Image_Index', 'Finding_Labels', 'Follow_Up_#', 'Patient_ID',
    'Patient_Age', 'Patient_Gender', 'View_Position',
    'Original_Image_Width', 'Original_Image_Height',
    'Original_Image_Pixel_Spacing_X',
    'Original_Image_Pixel_Spacing_Y', 'dfd',
]
# Multi-hot encode each finding (substring test on the raw '|'-joined string).
for diseases in tqdm(disease_labels):
    labels_df[diseases] = labels_df['Finding_Labels'].map(lambda findings: 1 if diseases in findings else 0)

# Unlike phase 1, 'No Finding' images are kept for phase 2.
labels_df['Finding_Labels'] = labels_df['Finding_Labels'].apply(lambda s: str(s).split('|'))

# Map every PNG file name to its path; missing images map to None.
num_glob = glob('/kaggle/input/data/*/images/*.png')
img_path = dict((os.path.basename(path), path) for path in num_glob)

labels_df['Paths'] = labels_df['Image_Index'].map(img_path.get)
labels_df.head()

In [None]:
# Sorted distinct patient IDs for the patient-level split.
unique_patients = np.unique(labels_df.Patient_ID)
len(unique_patients)

In [None]:
from sklearn.model_selection import train_test_split

# Patient-level 80/20 split; the DataBlock later takes 12.5% of train/val
# as validation (70 / 10 / 20 overall). Same SEED as phase 1.
train_val_df_patients, test_df_patients = train_test_split(
    unique_patients,
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
)
len(train_val_df_patients)

In [None]:
# Partition images according to the patient-level split.
train_val_df = labels_df.loc[labels_df.Patient_ID.isin(train_val_df_patients)]
test_df = labels_df.loc[labels_df.Patient_ID.isin(test_df_patients)]

In [None]:
# Split sizes; the test partition is everything outside train/val.
print('train_val size', len(train_val_df))
print('test size', len(labels_df) - len(train_val_df))

In [None]:
# Per-item resize (before batching).
item_transforms = [Resize((224, 224))]

# Batch-level augmentation, then ImageNet normalisation for the pretrained backbones.
batch_transforms = [Flip(), Rotate(), Normalize.from_stats(*imagenet_stats)]


def get_x(row):
    """Return the image path stored in a labels_df row."""
    return row['Paths']


def get_y(row):
    """Return the multi-hot target of a labels_df row, ordered like disease_labels."""
    return row[disease_labels].tolist()


dblock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock(encoded=True, vocab=disease_labels)),
    get_x=get_x,
    get_y=get_y,
    splitter=RandomSplitter(valid_pct=0.125, seed=SEED),
    item_tfms=item_transforms,
    batch_tfms=batch_transforms,
)
dls_phase2 = dblock.dataloaders(train_val_df, bs=128)
cbs_phase2 = [
    SaveModelCallback(monitor='valid_loss', min_delta=0.0001, with_opt=True),
    EarlyStoppingCallback(monitor='valid_loss', min_delta=0.001, patience=5),
    ShowGraphCallback(),
]
# NOTE(review): the RandomSplitter also splits test_df, so dls_test.valid
# holds only ~12.5% of the test images; score the full set with
# learn.dls.test_dl(test_df, with_labels=True).
dls_test = dblock.dataloaders(test_df, bs=32, shuffle=False)

In [None]:
from sklearn.metrics import roc_auc_score


def get_roc_auc(learner, dl=None, label_names=None):
    """Compute per-class and mean ROC-AUC for a fastai learner and print them.

    Args:
        learner: fastai Learner to evaluate.
        dl: Optional DataLoader to predict on, e.g.
            ``learner.dls.test_dl(test_df, with_labels=True)`` for a full test
            set. Defaults to the learner's validation split (``ds_idx=1``).
        label_names: Names printed next to the per-class scores; defaults to
            the module-level ``disease_labels``.

    Returns:
        dict with 'roc_auc' and 'mean_auc' (both the mean of the defined
        per-class AUCs, i.e. sklearn's macro average when all are defined),
        'class_auc' (list; NaN for a class with only one label value
        present), 'preds' and 'y_test'.
    """
    # NOTE(review): freeze() mutates the caller's learner and is not needed
    # for prediction; kept for backward compatibility.
    learner.freeze()
    if dl is None:
        preds, y_test = learner.get_preds(ds_idx=1)
    else:
        preds, y_test = learner.get_preds(dl=dl)

    # One AUC per output column (no hard-coded class count).
    scores = []
    for i in range(y_test.shape[1]):
        try:
            scores.append(roc_auc_score(y_test[:, i], preds[:, i]))
        except ValueError:
            # AUC is undefined when a class has only positives or only negatives.
            scores.append(float('nan'))
    names = label_names if label_names is not None else disease_labels
    print('ROC_AUC_Labels:', list(zip(names, scores)))

    defined = [s for s in scores if not np.isnan(s)]
    roc_auc = float(np.mean(defined)) if defined else float('nan')
    print(f'SCORE: {roc_auc}')
    return {
        'roc_auc': roc_auc,
        'mean_auc': roc_auc,
        'class_auc': scores,
        'preds': preds,
        'y_test': y_test
    }

In [None]:
# from fastai.vision.all import *

# cbs=[
#     SaveModelCallback(monitor='valid_loss', min_delta=0.0001, with_opt=True),
#     EarlyStoppingCallback(monitor='valid_loss', min_delta=0.001, patience=5),
#     ShowGraphCallback()
#     ]

# learn = create_fastai_learner(dls,cbs=cbs,loss_type='bce')
# lrs = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
# print('intial learning rate=', lrs.valley)
# learn.fine_tune(freeze_epochs=3,epochs=20, base_lr=lrs.valley)
# #learn.model = torch.nn.DataParallel(learn.model)
# print('--------------Phase1-Done---------------')



# print('--------------Begin-Phase2---------------')
# learn = create_fastai_learner(dls_phase2,cbs=cbs_phase2,loss_type='asymmetric')
# learn = learn.load('/kaggle/working/models/model')
# learn.unfreeze()
# learn.fit_one_cycle(5, slice(2e-5, 8e-5))
# print('--------------Phase2-Done---------------')




# print('--------------Begin-Testing---------------')
# learn = create_fastai_learner(dls_test,cbs=cbs_phase2,loss_type='asymmetric')
# learn = learn.load('/kaggle/working/models/model')
# model_result= get_roc_auc(learn)
# preds = modelv1_result['preds']
# torch.save(preds, 'modelv1_result.pt')


In [None]:
from fastai.vision.all import *
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score
import itertools
import time
import os


# Define the parameter grid
# Hyper-parameter grid; every combination is run by the main script.
momentum_values = [0.99]
k_values = [2, 3, 4, 5]
threshold_values = [0.2]

# Output directory for predictions, the results CSV and plots.
results_dir = '/kaggle/working/parameter_search_results'
os.makedirs(results_dir, exist_ok=True)

# One result dict per experiment is appended here.
results = []

# Callbacks passed to every learner of every run.
# NOTE(review): the same instances are reused across learners and runs.
base_cbs = [
    SaveModelCallback(monitor='valid_loss', min_delta=0.0001, with_opt=True),
    EarlyStoppingCallback(monitor='valid_loss', min_delta=0.001, patience=3),
    ShowGraphCallback(),
]
def run_experiment(momentum, k, threshold, exp_name):
    """Run a single experiment with specified parameters.

    Phase 1 fine-tunes with BCE on ``dls``. Phase 2 reloads the best phase-1
    checkpoint ``model_{exp_name}`` and trains with the asymmetric loss on
    ``dls_phase2``. Testing reloads ``model_{exp_name}`` and predicts on
    every image of ``test_df``.

    Args:
        momentum: Written to ``ModelConfig.MOMENTUM``.
        k: Written to ``ModelConfig.RETRIEVAL_K``.
        threshold: Written to ``ModelConfig.RARITY_THRESHOLD``.
        exp_name: Tag used in checkpoint and prediction file names.

    Returns:
        dict with 'momentum', 'k', 'threshold', 'mean_auc', 'class_auc',
        'training_time' and 'exp_name'. If the run fails, metrics are
        NaN/empty, an 'error' entry is added and the traceback is printed.
    """
    import gc
    import traceback

    print(f"\n=== Running experiment: momentum={momentum}, k={k}, threshold={threshold} ===")

    # ChestXrayModel reads these class attributes when it is constructed.
    ModelConfig.MOMENTUM = momentum
    ModelConfig.RETRIEVAL_K = k
    ModelConfig.RARITY_THRESHOLD = threshold

    # Create a model-specific save callback
    save_cb = SaveModelCallback(monitor='valid_loss', min_delta=0.0001,
                                fname=f"model_{exp_name}", with_opt=True)
    cbs = base_cbs + [save_cb]

    learn = test_learn = None
    try:
        start_time = time.time()

        # Phase 1: Initial training
        print('--------------Begin-Phase1---------------')
        learn = create_fastai_learner(dls, cbs=cbs, loss_type='bce')
        lrs = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
        lr = lrs.valley if lrs.valley is not None else 1e-4
        print('initial learning rate=', lr)
        learn.fine_tune(freeze_epochs=3, epochs=20, base_lr=lr)

        # Phase 2
        print('--------------Phase1-Done---------------')
        print('--------------Begin-Phase2---------------')
        learn = create_fastai_learner(dls_phase2, cbs=cbs, loss_type='asymmetric')
        learn = learn.load(f"/kaggle/working/models/model_{exp_name}")
        learn.unfreeze()
        learn.fit_one_cycle(5, slice(2e-5, 8e-5))

        # Testing phase
        print('--------------Phase2-Done---------------')
        print('--------------Begin-Testing---------------')
        test_learn = create_fastai_learner(dls_test, cbs=cbs, loss_type='asymmetric')
        test_learn = test_learn.load(f"/kaggle/working/models/model_{exp_name}")

        # Score every test image: dls_test.valid only holds the DataBlock
        # RandomSplitter's 12.5% slice of test_df. (get_roc_auc's original
        # result has no 'class_auc'/'mean_auc' keys, which made every run
        # fail with KeyError.)
        test_dl = test_learn.dls.test_dl(test_df, with_labels=True)
        preds, y_test = test_learn.get_preds(dl=test_dl)
        auc_scores = []
        for i in range(y_test.shape[1]):
            try:
                auc_scores.append(float(roc_auc_score(y_test[:, i], preds[:, i])))
            except ValueError:
                # Undefined when a class has a single label value in the test set.
                auc_scores.append(float('nan'))
        defined = [s for s in auc_scores if not np.isnan(s)]
        mean_auc = float(np.mean(defined)) if defined else float('nan')
        print('ROC_AUC_Labels:', list(zip(disease_labels, auc_scores)))
        print(f'SCORE: {mean_auc}')

        # Save predictions
        torch.save(preds, f'{results_dir}/preds_{exp_name}.pt')

        training_time = time.time() - start_time

        result = {
            'momentum': momentum,
            'k': k,
            'threshold': threshold,
            'mean_auc': mean_auc,
            'class_auc': auc_scores,
            'training_time': training_time,
            'exp_name': exp_name
        }

        print(f"Experiment completed. Mean AUC: {mean_auc:.4f}, Time: {training_time:.1f}s")
        return result

    except Exception as e:
        # Top-level boundary of one grid-search run: report and continue.
        print(f"Error in experiment: {str(e)}")
        traceback.print_exc()
        return {
            'momentum': momentum,
            'k': k,
            'threshold': threshold,
            'mean_auc': float('nan'),
            'class_auc': [],
            'training_time': float('nan'),
            'exp_name': exp_name,
            'error': str(e)
        }
    finally:
        # Drop this run's learners before the next experiment.
        # NOTE(review): the shared base_cbs keep a `.learn` reference, so the
        # last learner may stay alive until they are re-bound.
        del learn, test_learn
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Main script
def _plot_parameter_effect(results_df, column, title):
    """Bar-plot mean 'mean_auc' per value of ``column`` on the current axes.

    The y-axis is zoomed to [0.99 * min, 1.01 * max] of the group means so
    small differences stay visible.
    """
    effect = results_df.groupby(column)['mean_auc'].mean().reset_index()
    sns.barplot(x=column, y='mean_auc', data=effect)
    plt.title(title)
    plt.ylim(effect['mean_auc'].min() * 0.99, effect['mean_auc'].max() * 1.01)


if __name__ == "__main__":
    # Get all parameter combinations
    all_combinations = list(itertools.product(momentum_values, k_values, threshold_values))
    total_experiments = len(all_combinations)

    print(f"Starting grid search with {total_experiments} parameter combinations")

    # Run all experiments
    for i, (momentum, k, threshold) in enumerate(all_combinations):
        exp_name = f"m{momentum}_k{k}_t{threshold}"
        print(f"\nExperiment {i+1}/{total_experiments}")

        result = run_experiment(momentum, k, threshold, exp_name)
        results.append(result)

        # Save intermediate results after each experiment
        pd.DataFrame(results).to_csv(f'{results_dir}/grid_search_results.csv', index=False)

    # Convert results to DataFrame for analysis
    results_df = pd.DataFrame(results)

    # Only analyse when at least one run produced a score.
    if not results_df.empty and results_df['mean_auc'].notna().any():
        # idxmax returns an index *label*, so look it up with .loc.
        best_idx = results_df['mean_auc'].idxmax()
        best_params = results_df.loc[best_idx]

        print("\n=== Best Parameter Combination ===")
        print(f"Momentum: {best_params['momentum']}")
        print(f"K-value: {best_params['k']}")
        print(f"Threshold: {best_params['threshold']}")
        print(f"Mean AUC: {best_params['mean_auc']:.4f}")
        print(f"Experiment name: {best_params['exp_name']}")

        # 1. Heatmaps (k x threshold) for each momentum value
        for m in momentum_values:
            df_m = results_df[results_df['momentum'] == m]
            if not df_m.empty:
                pivot = df_m.pivot_table(index='k', columns='threshold', values='mean_auc')

                plt.figure(figsize=(10, 8))
                sns.heatmap(pivot, annot=True, fmt='.4f', cmap='viridis')
                plt.title(f'Mean AUC for Momentum = {m}')
                plt.tight_layout()
                plt.savefig(f'{results_dir}/heatmap_momentum_{m}.png')
                plt.close()

        # 2. Effect of individual parameters
        plt.figure(figsize=(18, 6))
        effect_panels = [
            ('momentum', 'Effect of Momentum'),
            ('k', 'Effect of K Value'),
            ('threshold', 'Effect of Threshold'),
        ]
        for position, (column, title) in enumerate(effect_panels, start=1):
            plt.subplot(1, 3, position)
            _plot_parameter_effect(results_df, column, title)

        plt.tight_layout()
        plt.savefig(f'{results_dir}/parameter_effects.png')
        plt.close()

        # 3. Top 10 combinations
        top10 = results_df.nlargest(10, 'mean_auc')
        plt.figure(figsize=(14, 8))

        # Create labels for x-axis
        labels = [f"m={row['momentum']}\nk={row['k']}\nt={row['threshold']}" for _, row in top10.iterrows()]

        plt.bar(range(len(labels)), top10['mean_auc'])
        plt.xticks(range(len(labels)), labels, rotation=45, ha='right')
        plt.title('Top 10 Parameter Combinations')
        plt.ylabel('Mean AUC')
        plt.tight_layout()
        plt.savefig(f'{results_dir}/top10_combinations.png')
        plt.close()
        

# Evaluation