In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchmetrics.functional import accuracy, recall, precision, auroc
from torch.utils.data import Dataset, DataLoader
import timm
import numpy as np
import pandas as pd
import random
from typing import List
from tqdm import tqdm
from torchmetrics.functional import accuracy, recall, precision, auroc
from PIL import Image, ImageEnhance, ImageOps, ImageFilter
import os
from glob import glob
from sklearn.model_selection import train_test_split

BASE_DIR = "/kaggle/input/grand-xray-slam-division-a"
label_columns = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
    'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion',
    'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices'
]


def setup_seed(seed=None):
    if seed is None:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Augmentation Spaces and Utilities

In [2]:
# used to freeze layers in a model
def freeze_all(model):
    for param in model.parameters():
        param.requires_grad=False

def get_resnet18(num_classes=14, fine_tune=True):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    if fine_tune:
        freeze_all(model)
        for param in model.layer4.parameters():
            param.requires_grad = True

    model.fc = nn.Linear(512, num_classes)
    return model

def get_resnet34(num_classes=14, fine_tune=True):
    model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

    if fine_tune:
        freeze_all(model)
        for param in model.layer4.parameters():
            param.requires_grad = True

    model.fc = nn.Linear(512, num_classes)
    return model

def get_effnetb0(num_classes=14, fine_tune=True):
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)

    if fine_tune:
        freeze_all(model)
        for idx in range(6, 8):
            for param in model.features[idx].parameters():
                param.requires_grad = True

    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=1280, out_features=num_classes)
    )
    
    return model

def get_convnext_tiny(num_classes=14, fine_tune=True):
    model = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1)

    if fine_tune:
        freeze_all(model)
        for param in model.features[6].parameters():
            param.requires_grad = True

    model.classifier[2] = nn.Linear(in_features=768, out_features=num_classes)
    return model

# returns getter function and number of in features of the classifier
models_ = {
    "res18": (get_resnet18, 512),
    "res34": (get_resnet34, 512),
    "effb0": (get_effnetb0, 1280),
    "convnext" : (get_convnext_tiny, 768),
}

In [3]:
# Set up transforms
basic_transforms = transforms.Compose([
    transforms.Resize((224, 224)), # resize to 224x224
    transforms.ToTensor(), # convert to tensor [0,1]
    transforms.Normalize( # normalize with ImageNet stats
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# custom transforms with adaptive magnitude
def shear_x(img, magnitude):
    level = magnitude * 0.3 * random.choice([-1, 1])
    return img.transform(img.size, Image.AFFINE, (1, level, 0, 0, 1, 0))

def shear_y(img, magnitude):
    level = magnitude * 0.3 * random.choice([-1, 1])
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, level, 1, 0))

def translate_x(img, magnitude):
    max_shift = 0.3 * img.size[0]
    level = magnitude * max_shift * random.choice([-1, 1])
    return img.transform(img.size, Image.AFFINE, (1, 0, level, 0, 1, 0))

def translate_y(img, magnitude):
    max_shift = 0.3 * img.size[1]
    level = magnitude * max_shift * random.choice([-1, 1])
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, level))

def rotate(img, magnitude):
    degrees = magnitude * 30 * random.choice([-1, 1])
    return img.rotate(degrees)

def contrast(img, magnitude):
    enhancer = ImageEnhance.Contrast(img)
    factor = 1.0 + (magnitude * random.choice([-0.9, 0.9]))
    return enhancer.enhance(factor)

def brightness(img, magnitude):
    enhancer = ImageEnhance.Brightness(img)
    factor = 1.0 + (magnitude * random.choice([-0.9, 0.9]))
    return enhancer.enhance(factor)

def sharpness(img, magnitude):
    enhancer = ImageEnhance.Sharpness(img)
    factor = 1.0 + (magnitude * random.choice([-0.9, 0.9]))
    return enhancer.enhance(factor)

def equalize(img, magnitude=None):
    return ImageOps.equalize(img)

def gaussian_blur(img, magnitude):
    radius = magnitude * 2
    return img.filter(ImageFilter.GaussianBlur(radius))

def identity(img, magnitude=None):
    return img

augmentation_space = [
    shear_x, shear_y, 
    translate_x, translate_y,
    rotate, equalize, 
    contrast, brightness,
    sharpness, identity
]

In [None]:
# AdaAugment can be used to control augmentation magnitudes and operations


class AdaAugment:

    def __init__(self, rand_m, rand_t):
        self.key_transform = {}
        self.key_magnitude = {}
        self.transforms = [
            shear_x, shear_y, 
            translate_x, translate_y,
            rotate, equalize, 
            contrast, brightness,
            sharpness, identity
        ]
        self.rand_m = rand_m
        self.rand_t = rand_t

    def set(self, keys, m=None, transform_idx=None):
        if m is not None:
            for i, key in enumerate(keys):
                self.key_magnitude[key] = m[i].cpu().detach()
            
        if transform_idx is not None:
            for i, key in enumerate(keys):
                self.key_transform[key] = transform_idx[i].cpu().detach()
    
    def __call__(self, key, img):
        """
        Apply an augmentation to the image based on stored magnitudes and transforms.
        
        Args:
            key: unique identifier for the sample
            img: input image to transform
    
        Returns:
            Augmented image
        """
        # --- Determine magnitude ---
        magnitude = self.key_magnitude.get(key)
        
        if magnitude is None:
            # No stored magnitude
            if self.rand_m:
                # Random magnitude
                if random.random() < 0.6:
                    return img  # skip transform 40% of the time
                magnitude = random.random()
            else:
                magnitude = 0.0
        else:
            # Magnitude exists
            if not self.rand_m and not self.rand_t:
                # Use magnitude corresponding to stored transform index
                transform_idx = int(self.key_transform.get(key))
                if isinstance(magnitude, (list, torch.Tensor)):
                    magnitude = float(magnitude[transform_idx])
                else:
                    magnitude = float(magnitude)
            else:
                magnitude = float(magnitude)
    
        # --- Determine transform ---
        transform_idx = self.key_transform.get(key)
        
        if transform_idx is None:
            # No stored transform
            transform = random.choice(self.transforms) if self.rand_t else self.transforms[-1]
        else:
            transform = self.transforms[int(transform_idx)]
    
        # --- Apply transform ---
        return transform(img, magnitude)


class SimpleAugment:

    def __init__(self, p=0.6, keys=None):
        self.keys = keys
        self.key_transform = {}
        self.key_magnitude = {}
        self.transforms = [
            shear_x, shear_y, 
            translate_x, translate_y,
            rotate, equalize, 
            contrast, brightness, 
            sharpness, identity
        ]
        self.transform_idx = list(range(len(self.transforms)))
        self.p = p

    def get(self, keys):
        collected_magnitude = []
        collected_transform = []
        collected_keys = []
        for key in keys:
            if key in self.keys:
                collected_keys.append(key)
                collected_magnitude.append(self.key_magnitude[key])
                collected_transform.append(self.key_transform[key])
        return collected_keys, torch.stack(collected_transform), torch.stack(collected_magnitude)
        
    def __call__(self, key, img):
        """
        Apply simple reandom augmentations to an image
        
        Args:
            key: unique identifier for the sample
            img: input image to transform
    
        Returns:
            Augmented image
        """
        if random.random() < self.p and key in self.keys:
            transform_idx = random.choice(self.transform_idx)
            transform = self.transforms[transform_idx]
            magnitude = random.random()
            self.key_transform[key] = transform_idx
            self.key_magnitude[key] = magnitude
            return transform(img, magnitude)

        else:
            return img

In [5]:
# Focal Loss for fine tuninng

class FocalLoss(nn.Module):

    def __init__(self, weights:List=None, gamma=2.0, reduction="mean"):
        super().__init__()

        self.weights = weights.to(device)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logit, target):
        bce_loss = nn.functional.binary_cross_entropy_with_logits(
            logit, target,
            reduction="none"
        )
        probs = torch.exp(-bce_loss)
        F_loss = self.weights.to(target.device) * (1-probs) ** self.gamma * bce_loss

        if self.reduction == "mean":
            return F_loss.mean()
        elif self.reduction == "none":
            return F_loss   
        else:
            return F_loss.sum()

In [6]:
# Saves per sample information such as previous loss etc.

class ValueMemory:
    def __init__(self):
        """
        Stores the last value per key (no EMA).
        """
        self.values = {}

    def __call__(self, keys, vals):
        """
        keys: list of sample identifiers
        vals: torch.Tensor of shape (len(keys), D)
        Returns: current stored values, previous stored values
        """
        stored_list = []
        new_list = []

        for i, key in enumerate(keys):
            val = vals[i]
            if key not in self.values:
                old = val.clone()  # nothing stored yet → use current
            else:
                old = self.values[key]

            self.values[key] = val  # overwrite with last value

            stored_list.append(old.unsqueeze(0))
            new_list.append(self.values[key].unsqueeze(0))

        stored = torch.cat(stored_list, dim=0)  # previous values
        return stored

    def get(self, key):
        """
        Access the stored value for a single key
        """
        return self.values.get(key, None)

    def get_multi(self, keys):
        """
        Access stored values for multiple keys
        """
        return torch.stack([self.values[k] for k in keys])

### Setting up the AdaAugment Agent
+ Critic => returns value for a state
+ Actor => Parameterizes a beta distribution that is used to sample augmentation magnitudes
+ Controller => Parameterizes a categorical distribution that is used to sample transformations from the augmentation space

In [None]:
class Actor(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.linear1 = nn.Linear(in_features, hidden)
        self.layernorm1 = nn.LayerNorm(hidden)
        self.linear2 = nn.Linear(hidden, hidden)
        #self.layernorm2 = nn.LayerNorm(hidden)
        self.alpha_head = nn.Linear(hidden, out_features)
        self.beta_head = nn.Linear(hidden, out_features)

    def forward(self, x):
        x = torch.relu(self.layernorm1(self.linear1(x)))
        x = torch.relu(self.linear2(x))
        return torch.softmax(self.alpha_head(x), dim=-1) + 1, torch.softmax(self.beta_head(x), dim=-1) + 1

    def get_dist(self, x):
        alpha, beta = self(x)
        dist = torch.distributions.Beta(alpha, beta)
        return dist


class Critic(nn.Module):

    def __init__(self, in_features, hidden):
        super().__init__()
        self.linear1 = nn.Linear(in_features, hidden)
        self.layernorm1 = nn.LayerNorm(hidden)
        self.linear2 = nn.Linear(hidden, hidden)
        #self.layernorm2 = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x):
        x = torch.relu(self.layernorm1(self.linear1(x)))
        x = torch.relu(self.linear2(x))
        x = self.head(x)
        return x


class Controller(nn.Module):

    def __init__(self, in_features, hidden):
        super().__init__()
        self.linear1 = nn.Linear(in_features, hidden)
        self.layernorm1 = nn.LayerNorm(hidden)
        self.linear2 = nn.Linear(hidden, hidden)
        #self.layernorm2 = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, len(augmentation_space))

    def forward(self, x):
        x = torch.relu(self.layernorm1(self.linear1(x)))
        x = torch.relu(self.linear2(x))
        x = self.head(x)
        return x
    
    def get_dist(self, x):
        out = self(x)
        dist = torch.distributions.Categorical(logits=out)
        return dist


class Agent(nn.Module):

    def __init__(
        self, 
        in_features,
        out_features, 
        control=False, 
        actor=False, 
        advantage_reg=True,
    ):
        super().__init__()
        self.val_memory = ValueMemory()
        self.loss_memory = ValueMemory()
        self.advantage_reg = advantage_reg

        self.critic = Critic(in_features=in_features, hidden=128)

        self.store_ = {}

        self.control_ = control
        self.actor_ = actor
        if control:
            self.controller = Controller(in_features=in_features, hidden=128)
            self.controller_optimizer = torch.optim.Adam(
                params=self.controller.parameters(), lr=3e-5, weight_decay=5e-4
            )

        if actor:
            self.actor = Actor(in_features=in_features, hidden=128, out_features=out_features) 

            self.actor_optimizer = torch.optim.Adam(
                params=self.actor.parameters(), lr=3e-5, weight_decay=5e-4
            )
        self.critic_optimizer = torch.optim.Adam(
            params=self.critic.parameters(), lr=3e-5, weight_decay=5e-4
        )

    def forward(self, state):
        action_actor = None
        action_controller = None

        if self.actor_:  
            dist = self.actor.get_dist(state.detach())
            action_actor = dist.sample()
            self.store_["action_actor"] = action_actor
            self.store_["dist_actor"] = dist

        if self.control_:
            dist = self.controller.get_dist(state.detach())
            action_controller = dist.sample()
            self.store_["action_controller"] = action_controller
            self.store_["dist_controller"] = dist

        return action_actor, action_controller       
    
    def set_action(self, action_actor, action_controller):
        self.store_["action_controller"] = action_controller
        self.store_["action_actor"] = action_actor

    def update(self, key, state, reward):
        value = self.critic(state)
        prev_state = self.val_memory(key, state)
        prev_value = self.critic(prev_state)
        
        with torch.no_grad():
            td_target = reward + 0.99 * value

        if self.actor_ and self.control_:
            dist_actor, action_actor = self.store_["dist_actor"], self.store_["action_actor"]
            dist_control, action_control = self.store_["dist_controller"], self.store_["action_controller"]
        
            # Select magnitude corresponding to the chosen augmentation
            chosen_magnitude = action_actor.gather(1, action_control.unsqueeze(1))

            # Controller chooses augmentation index
            aug_idx = action_control.unsqueeze(1)
            
            # Select alpha, beta for chosen augmentation
            alpha_chosen = dist_actor.concentration1.gather(1, aug_idx)
            beta_chosen  = dist_actor.concentration0.gather(1, aug_idx)
            
            # Build a Beta distribution only for the chosen augmentation
            dist_chosen = torch.distributions.Beta(alpha_chosen.squeeze(1), beta_chosen.squeeze(1))
            
            # Compute log_prob of chosen magnitude
            log_prob_actor = dist_chosen.log_prob(chosen_magnitude.squeeze(1))
        
            # Log-prob for augmentation choice
            log_prob_control = dist_control.log_prob(action_control)
        
            # Total log prob
            log_prob = log_prob_actor + log_prob_control

        elif self.actor_:
            dist, action = self.store_["dist_actor"], self.store_["action_actor"]
            log_prob = dist.log_prob(action)

        elif self.control_:
            dist, action = self.store_["dist_controller"], self.store_["action_controller"]
            log_prob = dist.log_prob(action)

        advantage = td_target - prev_value.detach()

        if self.advantage_reg:
            advantage = (advantage - advantage.mean()) (advantage.std() + 1e-8)

        actor_loss = -(log_prob * advantage).mean()

        if self.control_:
            self.controller_optimizer.zero_grad()
        if self.actor_:
            self.actor_optimizer.zero_grad()
        actor_loss.backward()
        if self.actor_:
            self.actor_optimizer.step()
        if self.control_:
            self.controller_optimizer.step()
        critic_loss = torch.nn.functional.mse_loss(td_target, prev_value)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        return actor_loss, critic_loss

### Dataset
+ Class weights
+ Dataset class

In [8]:
def get_class_weights(df):
    weights = []
    label_columns = [
        'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
        'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion',
        'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices'
    ]

    for label in label_columns:
        weight = len(df) / (df[label].sum() + 1e-6)
        weights.append(weight)

    return weights

In [9]:
class XRayDataset(Dataset):

    def __init__(self, df, train=True, transform=None):
        super().__init__()
        self.df = df
        self.train = train
        self.transform = transform
        self.label_columns = [
            'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
            'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion',
            'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices'
        ]
        if train:
            class_weights = np.array(get_class_weights(df), dtype=np.float32)
    
            sample_weights = (df[label_columns].values.astype(np.float32) * class_weights).mean(axis=1)
            sample_weights[sample_weights == 0] = 1.0
            self.df["weight"] = sample_weights

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.df.iloc[idx]["img_path"]
        img = Image.open(img_path).convert("RGB")
        label = self.df.iloc[idx][self.label_columns].values.astype(np.float32)
        key = self.df.iloc[idx]["Image_name"]
        
        if self.train and self.transform is not None:
            img = self.transform(key, img)

        img = basic_transforms(img)

        if not self.train:
            return key, img

        return key, img, torch.tensor(label, dtype=torch.float32)

### Adapt model to new dataset

In [10]:
train_csv = pd.read_csv("/kaggle/input/grand-xray-slam-division-a/train1.csv")
train_csv["img_path"] = train_csv["Image_name"].apply(lambda x: os.path.join(BASE_DIR, "train1", x))
submission_df = pd.read_csv("/kaggle/input/grand-xray-slam-division-a/sample_submission_1.csv")
train_df, val_df = train_test_split(
    train_csv,
    test_size=0.2,
    random_state=42,
    stratify=train_csv["No Finding"]
)

submission_df["img_path"] = submission_df["Image_name"].apply(lambda x : os.path.join(BASE_DIR, "test1", x))
class_weights = get_class_weights(train_df)
zipped_class_weights = list(zip(label_columns, class_weights))

print("Train dataset:")
for class_, weight in zipped_class_weights:
    print(f"{class_}: {weight:.2f}\n")


print("Validation dataset:")

class_weights = get_class_weights(val_df)
zipped_class_weights = list(zip(label_columns, class_weights))
for class_, weight in zipped_class_weights:
    print(f"{class_}: {weight:.2f}\n")


Atelectasis: 2.765

Cardiomegaly: 3.073

Consolidation: 3.666

Edema: 4.034

Enlarged Cardiomediastinum: 2.841

Fracture: 7.221

Lung Lesion: 9.104

Lung Opacity: 2.211

No Finding: 3.162

Pleural Effusion: 3.137

Pleural Other: 15.225

Pneumonia: 7.571

Pneumothorax: 12.202

Support Devices: 2.852



In [None]:
current_state = {}
def _register_head_hook(model):
    def hook(module, input, output):
        current_state['head_input'] = input[0].detach()
    
    if isinstance(model, models.EfficientNet):
        return model.classifier[-1].register_forward_hook(hook)
    elif isinstance(model, models.ConvNeXt):
        return model.classifier[2].register_forward_hook(hook)
    else:
        return model.classifier[-1].register_forward_hook(hook)

def _remove_head_hook(_hook_handle):
    _hook_handle.remove()

In [None]:
# Dataset
img_to_augment = random.choices(train_df["Image_name"].to_list(), k=int(len(train_df) * 0.3))

simpleaugment = SimpleAugment(p=1, keys=img_to_augment)

train_set = XRayDataset(train_df, train=True, transform=simpleaugment)
val_set = XRayDataset(val_df, train=True)
test_set = XRayDataset(submission_df, train=False)

# DataLoader
batch_size = 64
train_loader = DataLoader(
    train_set, 
    batch_size=batch_size, 
    shuffle=True, 
    pin_memory=True,
    num_workers=os.cpu_count()
)
val_loader = DataLoader(
    val_set, 
    batch_size=batch_size, 
    pin_memory=True,
    num_workers=os.cpu_count()
)
test_loader = DataLoader(
    test_set, 
    batch_size=batch_size, 
    pin_memory=True,
    num_workers=os.cpu_count()
)

In [None]:
# Models and functions
randm, randt = True, True
model_name = "convnext"
get_model, in_features = models_[model_name]
model = get_model(fine_tune=False)
observe = True

observing_agent = Agent(in_features, len(augmentation_space), not randt, not randm)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = FocalLoss(weights=torch.tensor(class_weights).to(device), reduction="none")

Downloading: "https://download.pytorch.org/models/convnext_tiny-983f1562.pth" to /root/.cache/torch/hub/checkpoints/convnext_tiny-983f1562.pth
100%|██████████| 109M/109M [00:00<00:00, 201MB/s] 


In [None]:
epochs = 5
r_weight = 0.5
r_norm = True
model.to(device)
observing_agent.to(device)
best_score = 0
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.1
)

for epoch in range(epochs):
    
    # Training
    if observe:
        handle = _register_head_hook(model)
    
    model.train()
    train_loss = 0
    for keys, imgs, targets in tqdm(train_loader, desc="Training"):
        imgs, targets = imgs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        logits = model(imgs)
        state = current_state["head_input"]

        loss = loss_fn(logits, targets)

        prev_loss = observing_agent.loss_memory(keys, loss.detach())
        probs = torch.sigmoid(logits)

        if randm or randt:
    
            _, _ = observing_agent(state)
            intersection_keys, transform_idx, magnitude =  simpleaugment.get(keys)
            observing_agent.set_action(magnitude, transform_idx)
    
            entropy = torch.sum(probs * torch.log(probs + 1e-8), dim=1).unsqueeze(1)
            reward = r_weight * (loss.detach().mean() - prev_loss.mean()) + (1 - r_weight) * entropy
            if r_norm:
                reward = (reward - reward.mean()) / (reward.std() - 1e-8)
    
            actor_loss, critic_loss = observing_agent.update(keys, state, reward)
            total_actor_loss += actor_loss.item()
            total_critic_loss += critic_loss.item()

        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Validation

    _remove_head_hook(handle)

    model.eval()
    val_loss = 0

    all_probs = []
    all_targets = []

    with torch.inference_mode():
        for keys, imgs, targets in tqdm(val_loader, desc="Validation"):
            imgs, targets = imgs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

            logits = model(imgs)
            loss = loss_fn(logits, targets)
            probs = torch.sigmoid(logits)

            val_loss += loss.item()

            all_probs.append(probs)
            all_targets.append(targets)

    val_loss /= len(val_loader)
    all_probs = torch.cat(all_probs, dim=0)
    all_targets = torch.cat(all_targets, dim=0)

    all_preds = (all_probs >= 0.5).int()

    val_accuracy = accuracy(all_preds, all_targets.int(), task="multilabel", num_labels=14)
    val_precision = precision(all_preds, all_targets.int(), task="multilabel", num_labels=14)
    val_recall = recall(all_preds, all_targets.int(), task="multilabel", num_labels=14)
    val_auroc = auroc(all_probs, all_targets.int(), task="multilabel", num_labels=14)


    if val_auroc > best_score:
        best_score = val_auroc
        initial_best_model = model.state_dict()
        torch.save(initial_best_model, "initial_best_model.pth")
        print(f"New best auroc score: {best_score:.4f}")

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Actor Loss: {actor_loss:.4f} | "
        f"Critic Loss: {critic_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"Acc: {val_accuracy:.4f} | "
        f"Prec: {val_precision:.4f} | "
        f"Rec: {val_recall:.4f} | "
        f"AUROC: {val_auroc:.4f}"
    )

Training: 100%|██████████| 1343/1343 [38:17<00:00,  1.71s/it]
Validation: 100%|██████████| 336/336 [08:38<00:00,  1.54s/it]


New best auroc score: 0.8748
Epoch 1/3 | Train Loss: 0.4860 | Val Loss: 0.4489 | Acc: 0.8568 | Prec: 0.7364 | Rec: 0.6812 | AUROC: 0.8748


Training: 100%|██████████| 1343/1343 [40:05<00:00,  1.79s/it]
Validation: 100%|██████████| 336/336 [08:33<00:00,  1.53s/it]


New best auroc score: 0.8948
Epoch 2/3 | Train Loss: 0.4189 | Val Loss: 0.4031 | Acc: 0.8704 | Prec: 0.7719 | Rec: 0.6966 | AUROC: 0.8948


Training: 100%|██████████| 1343/1343 [40:14<00:00,  1.80s/it]
Validation: 100%|██████████| 336/336 [08:35<00:00,  1.54s/it]


New best auroc score: 0.8965
Epoch 3/3 | Train Loss: 0.3940 | Val Loss: 0.4032 | Acc: 0.8735 | Prec: 0.8017 | Rec: 0.6680 | AUROC: 0.8965


### Finetuning using AdaAugment and Focal Loss

### Things to try out here:

+ Add back the Layernorm after Linear1
+ Advantage Norm
+ Reward Norm
+ Lr scheduler

In [None]:
randm, randt = False, False

ada_augment=AdaAugment(rand_m=randm, rand_t=randt)
train_set = XRayDataset(train_df, train=True, transform=ada_augment)

w_ = torch.tensor(train_set.df["weight"], dtype=torch.float32)
sampler = torch.utils.data.WeightedRandomSampler(
    weights=w_,
    num_samples=len(w_),
    replacement=True
)

# New dataset using adaptive augmentations
train_loader = DataLoader(
    train_set, 
    batch_size=batch_size,
    #sampler=sampler,
    shuffle=True, 
    pin_memory=True,
    num_workers=os.cpu_count()
)

get_model, in_features = models_[model_name]

model = get_model(fine_tune=True)

model.load_state_dict(torch.load("/kaggle/working/initial_best_model.pth", map_location=device, weights_only=False))

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

loss_fn = FocalLoss(weights=torch.tensor(class_weights, dtype=torch.float32), reduction="none")

agent = Agent(in_features, len(augmentation_space), not randt, not randm)
current_state = {}
best_score = 0

In [None]:
epochs = 5
r_weight = 1.0
r_norm = True
model.to(device)
agent.to(device)

    
for epoch in range(epochs):
    # Training
    
    model.train()
    handle = _register_head_hook(model)
    
    train_loss, total_actor_loss, total_critic_loss = 0, 0, 0
    for keys, imgs, targets in tqdm(train_loader, desc="Training"):
        imgs, targets = imgs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        logits = model(imgs)
        state = current_state["head_input"]

        loss = loss_fn(logits, targets)

        prev_loss = agent.loss_memory(keys, loss.detach())
        probs = torch.sigmoid(logits)

        if randm or randt:
    
            action, transform_idx = agent(state)
            ada_augment.set(keys, action, transform_idx)
    
            entropy = torch.sum(probs * torch.log(probs + 1e-8), dim=1).unsqueeze(1)
            reward = r_weight * (loss.detach().mean() - prev_loss.mean()) + (1 - r_weight) * entropy
            if r_norm:
                reward = (reward - reward.mean()) / (reward.std() - 1e-8)
    
            actor_loss, critic_loss = agent.update(keys, state, reward)
            total_actor_loss += actor_loss.item()
            total_critic_loss += critic_loss.item()

        loss = loss.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)
    total_actor_loss /= len(train_loader)
    total_critic_loss /= len(train_loader)

    _remove_head_hook(handle)
    # Validation
    model.eval()
    val_loss = 0

    all_probs = []
    all_targets = []

    with torch.inference_mode():
        for keys, imgs, targets in tqdm(val_loader, desc="Validation"):
            imgs, targets = imgs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

            logits = model(imgs)
            loss = loss_fn(logits, targets)
            loss = loss.mean()
            probs = torch.sigmoid(logits)

            val_loss += loss.item()

            all_probs.append(probs)
            all_targets.append(targets)

    val_loss /= len(val_loader)
    all_probs = torch.cat(all_probs, dim=0)
    all_targets = torch.cat(all_targets, dim=0)

    all_preds = (all_probs >= 0.5).int()

    val_accuracy = accuracy(all_preds, all_targets.int(), task="multilabel", num_labels=14)
    val_precision = precision(all_preds, all_targets.int(), task="multilabel", num_labels=14)
    val_recall = recall(all_preds, all_targets.int(), task="multilabel", num_labels=14)
    val_auroc = auroc(all_probs, all_targets.int(), task="multilabel", num_labels=14)

    if val_auroc > best_score:
        best_score = val_auroc
        initial_best_model = model.state_dict()
        torch.save(initial_best_model, "best_model.pth")
        print(f"New best auroc score: {best_score:.4f}")

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Actor Loss: {total_actor_loss:.4f} | "
        f"Critic Loss: {total_critic_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"Acc: {val_accuracy:.4f} | "
        f"Prec: {val_precision:.4f} | "
        f"Rec: {val_recall:.4f} | "
        f"AUROC: {val_auroc:.4f}"
    )

Training: 100%|██████████| 1343/1343 [33:32<00:00,  1.50s/it]
Validation: 100%|██████████| 336/336 [08:34<00:00,  1.53s/it]


New best auroc score: 0.8996
Epoch 1/5 | Train Loss: 0.3671 | Actor Loss: -4.2872 |Critic Loss: 3.7018 |Val Loss: 0.3924 | Acc: 0.8760 | Prec: 0.7881 | Rec: 0.7014 | AUROC: 0.8996


Training: 100%|██████████| 1343/1343 [1:03:15<00:00,  2.83s/it]
Validation: 100%|██████████| 336/336 [08:38<00:00,  1.54s/it]


New best auroc score: 0.9005
Epoch 2/5 | Train Loss: 0.4004 | Actor Loss: 0.3604 |Critic Loss: 65.4283 |Val Loss: 0.3906 | Acc: 0.8759 | Prec: 0.7811 | Rec: 0.7120 | AUROC: 0.9005


Training: 100%|██████████| 1343/1343 [1:05:40<00:00,  2.93s/it]
Validation: 100%|██████████| 336/336 [08:42<00:00,  1.56s/it]


New best auroc score: 0.9008
Epoch 3/5 | Train Loss: 0.3967 | Actor Loss: -2.9447 |Critic Loss: 56.7383 |Val Loss: 0.3901 | Acc: 0.8762 | Prec: 0.7819 | Rec: 0.7118 | AUROC: 0.9008


Training: 100%|██████████| 1343/1343 [1:04:18<00:00,  2.87s/it]
Validation: 100%|██████████| 336/336 [09:01<00:00,  1.61s/it]


New best auroc score: 0.9011
Epoch 4/5 | Train Loss: 0.3955 | Actor Loss: -3.6595 |Critic Loss: 152.5614 |Val Loss: 0.3898 | Acc: 0.8763 | Prec: 0.7830 | Rec: 0.7110 | AUROC: 0.9011


Training: 100%|██████████| 1343/1343 [1:04:28<00:00,  2.88s/it]
Validation: 100%|██████████| 336/336 [08:53<00:00,  1.59s/it]


New best auroc score: 0.9013
Epoch 5/5 | Train Loss: 0.3936 | Actor Loss: -2.8835 |Critic Loss: 222.7858 |Val Loss: 0.3896 | Acc: 0.8763 | Prec: 0.7825 | Rec: 0.7116 | AUROC: 0.9013


In [19]:
def create_submission(model_name=model_name):
    get_model, _ = models_[model_name]
    model = get_model()
    model.load_state_dict(torch.load("/kaggle/working/best_model.pth", map_location=device, weights_only=False))

    model.to(device)
    model.eval()
    
    all_keys = []
    all_probs = []
    
    with torch.inference_mode():
        for key, img in tqdm(test_loader):
            img = img.to(device)

            out = model(img)
            probs = torch.sigmoid(out)

            all_keys.extend(key)
            all_probs.append(probs.cpu().numpy())
            
    all_probs = np.concatenate(all_probs, axis=0)

    return all_keys, all_probs

In [20]:
all_keys, all_probs = create_submission()

100%|██████████| 723/723 [17:26<00:00,  1.45s/it]


In [37]:
threshold = 0.5
binary_ = False
final_submission = pd.DataFrame()
final_submission["Image_name"] = all_keys
if binary_:
    final_submission[label_columns] = (probs_ > threshold).astype(int)
else:
    final_submission[label_columns] = probs_
final_submission.to_csv('/kaggle/working/submission.csv', index=False)

print("="* 70)
print("Final submission file created at /kaggle/working/submission.csv")
print("="* 70)

Final submission file created at /kaggle/working/submission.csv


In [38]:
pd.read_csv("/kaggle/working/submission.csv")

Unnamed: 0,Image_name,Atelectasis,Cardiomegaly,Consolidation,Edema,Enlarged Cardiomediastinum,Fracture,Lung Lesion,Lung Opacity,No Finding,Pleural Effusion,Pleural Other,Pneumonia,Pneumothorax,Support Devices
0,00000005_001_001.jpg,0.220297,0.149363,0.195031,0.123469,0.210709,0.305022,0.233154,0.258701,0.529039,0.193854,0.235400,0.223850,0.256924,0.454828
1,00000005_001_002.jpg,0.225101,0.214264,0.214019,0.186304,0.278467,0.276405,0.255477,0.275641,0.379550,0.217344,0.197508,0.234026,0.256407,0.632703
2,00000005_002_001.jpg,0.720314,0.549541,0.610916,0.405752,0.676551,0.461761,0.292920,0.745153,0.107114,0.633497,0.239078,0.392544,0.475623,0.698820
3,00000005_002_002.jpg,0.691454,0.401686,0.602385,0.266316,0.526638,0.593861,0.363133,0.690712,0.113465,0.611649,0.352090,0.489921,0.658006,0.696855
4,00000007_001_001.jpg,0.586692,0.623392,0.428747,0.492691,0.694068,0.263567,0.216696,0.654815,0.174051,0.480549,0.145823,0.236672,0.230011,0.654003
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46228,20009235_000_000.jpg,0.275094,0.163307,0.203308,0.161742,0.098608,0.082475,0.346820,0.349803,0.513322,0.266759,0.206782,0.154588,0.200309,0.085841
46229,20009236_000_000.jpg,0.205679,0.128511,0.131184,0.098202,0.087050,0.080408,0.279345,0.292245,0.639288,0.140768,0.156294,0.128068,0.126906,0.079985
46230,20009238_000_000.jpg,0.199245,0.130255,0.133074,0.102610,0.093344,0.076683,0.273192,0.301521,0.619003,0.157510,0.168086,0.113717,0.146211,0.091032
46231,20009240_000_000.jpg,0.346170,0.223896,0.176529,0.144243,0.143350,0.093923,0.204498,0.346055,0.581372,0.239030,0.141905,0.142443,0.108583,0.116798
