In [4]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchmetrics.functional import accuracy, recall, precision, auroc
from torch.utils.data import Dataset, DataLoader
import timm
import numpy as np
import pandas as pd
import random
from typing import List
from tqdm import tqdm
from torchmetrics.functional import accuracy, recall, precision, auroc
from PIL import Image, ImageEnhance, ImageOps, ImageFilter
import os
from glob import glob
from sklearn.model_selection import train_test_split

# Root directory of the competition data on Kaggle.
BASE_DIR = "/kaggle/input/grand-xray-slam-division-a"

# The 14 finding labels, in the column order used throughout the notebook.
label_columns = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]


def setup_seed(seed=None):
    """Seed every RNG used in the notebook and make cuDNN deterministic.

    Args:
        seed: Integer seed. When ``None`` nothing is touched.
    """
    if seed is not None:
        seeders = (
            torch.manual_seed,
            torch.cuda.manual_seed,
            torch.cuda.manual_seed_all,
            np.random.seed,
            random.seed,
        )
        for seed_fn in seeders:
            seed_fn(seed)

        # Deterministic kernels, no autotuning.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Augmentation Spaces and Utilities

In [3]:
def freeze_all(model):
    """Freeze a model: turn off gradient tracking for every parameter."""
    for tensor in model.parameters():
        tensor.requires_grad_(False)

def get_resnet18(num_classes=14, fine_tune=True):
    """Build an ImageNet-pretrained ResNet-18 with a fresh ``num_classes`` head.

    Args:
        num_classes: Output width of the replacement ``fc`` layer.
        fine_tune: When true, freeze everything except ``layer4``; when
            false, leave the whole backbone trainable.

    Returns:
        The configured torchvision model.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    if fine_tune:
        freeze_all(net)
        net.layer4.requires_grad_(True)

    # The head is replaced after freezing, so it is always trainable.
    net.fc = nn.Linear(512, num_classes)
    return net

def get_resnet34(num_classes=14, fine_tune=True):
    """Build an ImageNet-pretrained ResNet-34 with a fresh ``num_classes`` head.

    Args:
        num_classes: Output width of the replacement ``fc`` layer.
        fine_tune: When true, freeze everything except ``layer4``; when
            false, leave the whole backbone trainable.

    Returns:
        The configured torchvision model.
    """
    net = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

    if fine_tune:
        freeze_all(net)
        net.layer4.requires_grad_(True)

    # The head is replaced after freezing, so it is always trainable.
    net.fc = nn.Linear(512, num_classes)
    return net

def get_effnetb0(num_classes=14, fine_tune=True):
    """Build an ImageNet-pretrained EfficientNet-B0 with a fresh classifier.

    Args:
        num_classes: Output width of the replacement classifier.
        fine_tune: When true, freeze everything except ``features[6]`` and
            ``features[7]``; when false, leave the backbone trainable.

    Returns:
        The configured torchvision model.
    """
    net = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)

    if fine_tune:
        freeze_all(net)
        for stage in (net.features[6], net.features[7]):
            stage.requires_grad_(True)

    # New dropout + linear head (always trainable because it is created after freezing).
    head = [
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=1280, out_features=num_classes),
    ]
    net.classifier = nn.Sequential(*head)

    return net

def get_convnext_tiny(num_classes=14, fine_tune=True):
    """Build an ImageNet-pretrained ConvNeXt-Tiny with a fresh final linear layer.

    Args:
        num_classes: Output width of the replacement ``classifier[2]`` layer.
        fine_tune: When true, freeze everything except ``features[6]``; when
            false, leave the backbone trainable.

    Returns:
        The configured torchvision model.
    """
    net = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1)

    if fine_tune:
        freeze_all(net)
        # NOTE(review): only features[6] is unfrozen; confirm that the last
        # stage (features[7]) is meant to stay frozen.
        net.features[6].requires_grad_(True)

    net.classifier[2] = nn.Linear(in_features=768, out_features=num_classes)
    return net

# Registry: model key -> (builder function, input width of the builder's classifier head).
models_ = dict(
    res18=(get_resnet18, 512),
    res34=(get_resnet34, 512),
    effb0=(get_effnetb0, 1280),
    convnext=(get_convnext_tiny, 768),
)

In [4]:
# Deterministic preprocessing applied to every image (train, validation and test).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

basic_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # fixed 224x224 input
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),  # ImageNet statistics
    ]
)

# Custom PIL transforms; each takes the image and a magnitude that scales its strength.
def shear_x(img, magnitude):
    """Shear along x by ``0.3 * magnitude`` with a random sign."""
    sign = random.choice([-1, 1])
    shear = magnitude * 0.3 * sign
    coeffs = (1, shear, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs)

def shear_y(img, magnitude):
    """Shear along y by ``0.3 * magnitude`` with a random sign."""
    sign = random.choice([-1, 1])
    shear = magnitude * 0.3 * sign
    coeffs = (1, 0, 0, shear, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs)

def translate_x(img, magnitude):
    """Shift horizontally by up to 30% of the width, scaled by ``magnitude``, random sign."""
    width = img.size[0]
    max_shift = 0.3 * width
    sign = random.choice([-1, 1])
    offset = magnitude * max_shift * sign
    return img.transform(img.size, Image.AFFINE, (1, 0, offset, 0, 1, 0))

def translate_y(img, magnitude):
    """Shift vertically by up to 30% of the height, scaled by ``magnitude``, random sign."""
    height = img.size[1]
    max_shift = 0.3 * height
    sign = random.choice([-1, 1])
    offset = magnitude * max_shift * sign
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, offset))

def rotate(img, magnitude):
    """Rotate by ``30 * magnitude`` degrees in a randomly chosen direction."""
    direction = random.choice([-1, 1])
    angle = magnitude * 30 * direction
    return img.rotate(angle)

def contrast(img, magnitude):
    """Scale contrast by a factor of ``1 +/- 0.9 * magnitude`` (sign chosen at random)."""
    adjuster = ImageEnhance.Contrast(img)
    delta = magnitude * random.choice([-0.9, 0.9])
    return adjuster.enhance(1.0 + delta)

def brightness(img, magnitude):
    """Scale brightness by a factor of ``1 +/- 0.9 * magnitude`` (sign chosen at random)."""
    adjuster = ImageEnhance.Brightness(img)
    delta = magnitude * random.choice([-0.9, 0.9])
    return adjuster.enhance(1.0 + delta)

def sharpness(img, magnitude):
    """Scale sharpness by a factor of ``1 +/- 0.9 * magnitude`` (sign chosen at random)."""
    adjuster = ImageEnhance.Sharpness(img)
    delta = magnitude * random.choice([-0.9, 0.9])
    return adjuster.enhance(1.0 + delta)

def equalize(img, magnitude=None):
    """Histogram-equalize the image; ``magnitude`` is accepted only for a uniform signature."""
    equalized = ImageOps.equalize(img)
    return equalized

def gaussian_blur(img, magnitude):
    """Gaussian-blur the image with radius ``2 * magnitude``."""
    blur = ImageFilter.GaussianBlur(magnitude * 2)
    return img.filter(blur)

def identity(img, magnitude=None):
    """No-op augmentation: hand the image back untouched (``magnitude`` is ignored)."""
    unchanged = img
    return unchanged

# Candidate operations, addressed by index; `identity` is kept last on purpose
# because AdaAugment falls back to `transforms[-1]`.
augmentation_space = [
    shear_x,
    shear_y,
    translate_x,
    translate_y,
    rotate,
    equalize,
    contrast,
    brightness,
    sharpness,
    identity,
]

In [5]:
# AdaAugment can be used to control augmentation magnitudes and operations

class AdaAugment:
    """Per-sample augmentation policy keyed by sample identifier.

    Magnitudes and transform indices chosen by an agent are recorded with
    :meth:`set`; :meth:`__call__` then applies them to the matching sample.
    Samples without a recorded choice fall back to random or neutral values
    depending on ``rand_m`` / ``rand_t``.

    Args:
        rand_m: If true, samples without a stored magnitude are either left
            untouched (probability 0.4) or get a uniform random magnitude.
            If false they get magnitude 0.
        rand_t: If true, samples without a stored transform get a random
            operation from ``augmentation_space``; otherwise the last entry
            (``identity``) is used.
    """

    def __init__(self, rand_m, rand_t):
        self.key_transform = {}
        self.key_magnitude = {}
        self.transforms = augmentation_space
        self.rand_m = rand_m
        self.rand_t = rand_t

    def set(self, keys, m, transform_idx=None):
        """Record the agent's choices for a batch of samples.

        Args:
            keys: Iterable of sample identifiers.
            m: Magnitudes indexable by batch position (one scalar per
                sample), or ``None`` when no magnitude was sampled.
            transform_idx: Optional transform indices, indexable by batch
                position.
        """
        # Either argument may be None when the corresponding agent head is
        # disabled; previously `m=None` raised TypeError on `m[i]`.
        if m is not None:
            for i, key in enumerate(keys):
                # Plain Python scalars: no autograd graph or device memory is retained.
                self.key_magnitude[key] = float(m[i])

        if transform_idx is not None:
            for i, key in enumerate(keys):
                self.key_transform[key] = int(transform_idx[i])

    def __call__(self, key, img):
        """Apply the stored (or fallback) transform and magnitude to ``img``."""
        # Magnitude for the sample
        m = self.key_magnitude.get(key)
        if m is None:
            if self.rand_m:
                if random.random() < 0.4:
                    return img  # skip the transform entirely
                m = random.random()  # uniform random magnitude
            else:
                m = 0
        else:
            m = float(m)

        # Transform for the sample
        t = self.key_transform.get(key)
        if t is None:
            # Random operation when rand_t is set, otherwise the last entry (identity).
            t = random.choice(self.transforms) if self.rand_t else self.transforms[-1]
        else:
            t = self.transforms[int(t)]

        return t(img, m)


In [6]:
# Focal Loss for fine-tuning

class FocalLoss(nn.Module):
    """Binary focal loss on logits with optional per-class weights.

    Args:
        weights: Optional per-class weights (list, array or tensor) that
            broadcast against the loss; ``None`` means unweighted.
        gamma: Focusing exponent applied to ``1 - p_t``.
        reduction: ``"mean"``, ``"none"``, or anything else for a sum.
    """

    def __init__(self, weights: List = None, gamma=2.0, reduction="mean"):
        super().__init__()

        # Lists cannot be moved between devices, so convert once here.
        if weights is not None:
            weights = torch.as_tensor(weights, dtype=torch.float32)
        self.weights = weights
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logit, target):
        """Return the focal loss for ``logit`` against the binary ``target``."""
        bce_loss = nn.functional.binary_cross_entropy_with_logits(
            logit, target,
            reduction="none"
        )
        # exp(-BCE) is the probability the model assigns to the true label.
        probs = torch.exp(-bce_loss)
        F_loss = (1 - probs) ** self.gamma * bce_loss
        if self.weights is not None:
            F_loss = self.weights.to(device=F_loss.device, dtype=F_loss.dtype) * F_loss

        if self.reduction == "mean":
            return F_loss.mean()
        elif self.reduction == "none":
            return F_loss
        else:
            return F_loss.sum()

In [7]:
# Saves per-sample information such as the previous loss or state.

class ValueMemory:
    def __init__(self):
        """
        Stores the last value per key (no EMA).
        """
        self.values = {}

    def __call__(self, keys, vals):
        """
        Swap in new values and return the ones that were stored before.

        keys: list of sample identifiers
        vals: torch.Tensor indexable along dim 0, one entry per key
        Returns: tensor of the previously stored values, stacked along
                 dim 0. For a key seen for the first time, the current
                 value is returned instead.
        """
        previous = []

        for i, key in enumerate(keys):
            # Detach and clone so the memory neither keeps an autograd graph
            # alive nor pins the whole batch tensor through a view.
            current = vals[i].detach().clone()
            old = self.values.get(key, current)
            self.values[key] = current  # overwrite with the latest value
            previous.append(old.unsqueeze(0))

        return torch.cat(previous, dim=0)

    def get(self, key):
        """
        Access the stored value for a single key (None if unseen)
        """
        return self.values.get(key, None)

    def get_multi(self, keys):
        """
        Access stored values for multiple keys, stacked along dim 0
        """
        return torch.stack([self.values[k] for k in keys])

### Setting up the AdaAugment Agent
+ Critic => estimates the value of a state
+ Actor => parameterizes a Beta distribution from which augmentation magnitudes are sampled
+ Controller => parameterizes a categorical distribution from which a transformation in the augmentation space is sampled

In [8]:
class Actor(nn.Module):
    """Policy head that parameterizes a Beta distribution over magnitudes.

    Args:
        in_features: Width of the state vector.
        hidden: Width of the two hidden layers.
        out_features: Number of magnitudes to model.
    """

    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.linear1 = nn.Linear(in_features, hidden)
        self.layer_norm1 = nn.LayerNorm(hidden)
        self.linear2 = nn.Linear(hidden, hidden)
        self.alpha_head = nn.Linear(hidden, out_features)
        self.beta_head = nn.Linear(hidden, out_features)

    def forward(self, x):
        """Return ``(alpha, beta)``, both strictly greater than 1."""
        # LayerNorm is sized for `hidden`, so it must follow linear1; applying
        # it to the raw input fails whenever in_features != hidden.
        x = torch.relu(self.layer_norm1(self.linear1(x)))
        x = torch.relu(self.linear2(x))
        # softplus + 1 keeps the Beta unimodal. The previous softmax over the
        # output dimension is constant for a single-unit head, which pinned
        # the policy to Beta(2, 2).
        alpha = nn.functional.softplus(self.alpha_head(x)) + 1
        beta = nn.functional.softplus(self.beta_head(x)) + 1
        return alpha, beta

    def get_dist(self, x):
        """Return the ``torch.distributions.Beta`` for state ``x``."""
        alpha, beta = self(x)
        dist = torch.distributions.Beta(alpha, beta)
        return dist


class Critic(nn.Module):
    """State-value network: maps a state vector to one scalar per sample.

    Args:
        in_features: Width of the state vector.
        hidden: Width of the two hidden layers.
    """

    def __init__(self, in_features, hidden):
        super().__init__()
        self.linear1 = nn.Linear(in_features, hidden)
        self.layer_norm1 = nn.LayerNorm(hidden)
        self.linear2 = nn.Linear(hidden, hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return the value estimate with a trailing dimension of size 1."""
        # LayerNorm is sized for `hidden`, so it must follow linear1; applying
        # it to the raw input fails whenever in_features != hidden.
        x = torch.relu(self.layer_norm1(self.linear1(x)))
        x = torch.relu(self.linear2(x))
        return self.head(x)


class Controller(nn.Module):
    """Policy head that scores every operation in ``augmentation_space``.

    Args:
        in_features: Width of the state vector.
        hidden: Width of the two hidden layers.
    """

    def __init__(self, in_features, hidden):
        super().__init__()
        self.linear1 = nn.Linear(in_features, hidden)
        self.layer_norm1 = nn.LayerNorm(hidden)
        self.linear2 = nn.Linear(hidden, hidden)
        self.head = nn.Linear(hidden, len(augmentation_space))

    def forward(self, x):
        """Return unnormalized logits, one per augmentation operation."""
        # LayerNorm is sized for `hidden`, so it must follow linear1; applying
        # it to the raw input fails whenever in_features != hidden.
        x = torch.relu(self.layer_norm1(self.linear1(x)))
        x = torch.relu(self.linear2(x))
        return self.head(x)

    def get_dist(self, x):
        """Return the categorical distribution over operations for state ``x``."""
        out = self(x)
        dist = torch.distributions.Categorical(logits=out)
        return dist


class Agent(nn.Module):
    """Actor-critic agent that picks augmentation magnitudes and/or operations.

    Args:
        in_features: Width of the state vector fed to every head.
        control: Build a ``Controller`` (samples the transform index).
        actor: Build an ``Actor`` (samples the magnitude).

    Attributes:
        actor_optimizer: Adam over all enabled policy heads (actor and
            controller), or ``None`` when neither is enabled.
        critic_optimizer: Adam over the critic.
    """

    def __init__(self, in_features, control=False, actor=False):
        super().__init__()
        self.val_memory = ValueMemory()
        self.loss_memory = ValueMemory()

        self.critic = Critic(in_features=in_features, hidden=128)

        # Distributions and actions from the latest `action` call, consumed by `update`.
        self.store_ = {}

        self.control_ = control
        self.actor_ = actor

        # Collect the parameters of whichever policy heads exist. The optimizer
        # used to read `self.actor` unconditionally (AttributeError when
        # actor=False) and never trained the controller.
        policy_params = []
        if control:
            self.controller = Controller(in_features=in_features, hidden=128)
            policy_params += list(self.controller.parameters())

        if actor:
            self.actor = Actor(in_features=in_features, hidden=128, out_features=1)
            policy_params += list(self.actor.parameters())

        self.actor_optimizer = (
            torch.optim.Adam(params=policy_params, lr=3e-5, weight_decay=5e-4)
            if policy_params else None
        )
        self.critic_optimizer = torch.optim.Adam(
            params=self.critic.parameters(), lr=3e-5, weight_decay=5e-4
        )

    def action(self, state):
        """Sample actions for ``state`` and remember them for ``update``.

        Returns:
            ``(magnitude, transform_idx)``; an entry is ``None`` when the
            corresponding head is disabled.
        """
        action_actor = None
        action_controller = None
        if self.actor_:
            dist = self.actor.get_dist(state.detach())
            action_actor = dist.sample()
            self.store_["action_actor"] = action_actor
            self.store_["dist_actor"] = dist
        if self.control_:
            dist = self.controller.get_dist(state.detach())
            action_controller = dist.sample()
            self.store_["action_controller"] = action_controller
            self.store_["dist_controller"] = dist

        return action_actor, action_controller

    def update(self, key, state, reward):
        """Run one actor-critic step using the actions stored by ``action``.

        Args:
            key: Sample identifiers used to look up each sample's previous state.
            state: Current state batch, one row per key.
            reward: Scalar, or one reward per sample (any shape that flattens
                to the batch size).

        Returns:
            ``(actor_loss, critic_loss)`` as detached scalar tensors. The
            actor loss is zero when no policy head is enabled.
        """
        state = state.detach()
        # Work with flat (B,) tensors throughout: mixing (B, 1) values with
        # (B,) log-probabilities silently broadcasts to (B, B).
        value = self.critic(state).squeeze(-1)
        prev_state = self.val_memory(key, state)
        prev_value = self.critic(prev_state).squeeze(-1)

        with torch.no_grad():
            reward = torch.as_tensor(reward, dtype=value.dtype, device=value.device)
            if reward.dim() > 0:
                reward = reward.reshape(-1)
            td_target = reward + 0.99 * value
            advantage = td_target - prev_value

        log_probs = []
        if self.actor_:
            lp = self.store_["dist_actor"].log_prob(self.store_["action_actor"])
            # Joint log-probability over the magnitude dimensions -> (B,)
            log_probs.append(lp.reshape(lp.shape[0], -1).sum(dim=-1))
        if self.control_:
            log_probs.append(
                self.store_["dist_controller"].log_prob(self.store_["action_controller"])
            )

        if log_probs and self.actor_optimizer is not None:
            actor_loss = -(sum(log_probs) * advantage).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
        else:
            actor_loss = torch.zeros((), device=value.device)

        critic_loss = torch.nn.functional.mse_loss(prev_value, td_target)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        return actor_loss.detach(), critic_loss.detach()

### Dataset
+ Class weights
+ Dataset class

In [11]:
def get_class_weights(df):
    """Return one inverse-frequency weight per label column.

    Each weight is ``len(df) / (positives + 1e-6)``, in the fixed order of the
    14 finding labels below.
    """
    label_columns = [
        'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
        'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion',
        'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices'
    ]

    n_rows = len(df)
    # The epsilon guards against division by zero for labels with no positives.
    return [n_rows / (df[name].sum() + 1e-6) for name in label_columns]

In [12]:
class XRayDataset(Dataset):
    """Chest X-ray dataset yielding ``(key, image[, labels])``.

    Args:
        df: DataFrame with ``img_path`` and ``Image_name`` columns and, when
            ``train`` is true, the 14 label columns.
        train: When true, labels are returned and a per-sample ``weight``
            column is added to the dataset's own copy of ``df``. When false,
            only ``(key, image)`` is returned.
        transform: Optional callable ``transform(key, img)`` applied to the
            PIL image before the basic preprocessing (training mode only).
    """

    def __init__(self, df, train=True, transform=None):
        super().__init__()
        # Work on a private copy in training mode: a "weight" column is added
        # below and must not leak into (or warn about) the caller's frame.
        self.df = df.copy() if train else df
        self.train = train
        self.transform = transform
        self.label_columns = [
            'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
            'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion',
            'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices'
        ]
        if train:
            class_weights = np.array(get_class_weights(self.df), dtype=np.float32)

            # Mean of the class weights over positive labels (zeros elsewhere).
            labels = self.df[self.label_columns].values.astype(np.float32)
            sample_weights = (labels * class_weights).mean(axis=1)
            # Rows without any positive label get a neutral weight.
            sample_weights[sample_weights == 0] = 1.0
            self.df["weight"] = sample_weights

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        key = row["Image_name"]

        # Context manager closes the underlying file once the pixels are converted.
        with Image.open(row["img_path"]) as raw:
            img = raw.convert("RGB")

        if self.train and self.transform is not None:
            img = self.transform(key, img)

        img = basic_transforms(img)

        if not self.train:
            return key, img

        # Labels are only read in training mode, so a test frame needs no label columns.
        label = row[self.label_columns].values.astype(np.float32)
        return key, img, torch.tensor(label, dtype=torch.float32)

### Adapt entire Model to new dataset

In [13]:
# Load the metadata; every path is derived from BASE_DIR so the data location
# is configured in exactly one place.
train_csv = pd.read_csv(os.path.join(BASE_DIR, "train1.csv"))
train_csv["img_path"] = train_csv["Image_name"].apply(lambda x: os.path.join(BASE_DIR, "train1", x))
submission_df = pd.read_csv(os.path.join(BASE_DIR, "sample_submission_1.csv"))
submission_df["img_path"] = submission_df["Image_name"].apply(lambda x: os.path.join(BASE_DIR, "test1", x))

# 80/20 split, stratified on the "No Finding" label.
train_df, val_df = train_test_split(
    train_csv,
    test_size=0.2,
    random_state=42,
    stratify=train_csv["No Finding"]
)
# Fresh 0..n-1 indices so positional and label-based access agree downstream.
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

class_weights = get_class_weights(train_df)
zipped_class_weights = list(zip(label_columns, class_weights))

for class_, weight in zipped_class_weights:
    print(f"{class_}: {weight:.3f}\n")

Atelectasis: 2.765

Cardiomegaly: 3.073

Consolidation: 3.666

Edema: 4.034

Enlarged Cardiomediastinum: 2.841

Fracture: 7.221

Lung Lesion: 9.104

Lung Opacity: 2.211

No Finding: 3.162

Pleural Effusion: 3.137

Pleural Other: 15.225

Pneumonia: 7.571

Pneumothorax: 12.202

Support Devices: 2.852



In [14]:
# Datasets (the validation set uses train=True so that labels are returned)
train_set = XRayDataset(train_df, train=True)
val_set = XRayDataset(val_df, train=True)
test_set = XRayDataset(submission_df, train=False)

# DataLoaders share everything except shuffling
batch_size = 32
_loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    num_workers=os.cpu_count(),
)
train_loader = DataLoader(train_set, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_set, **_loader_kwargs)
test_loader = DataLoader(test_set, **_loader_kwargs)

In [15]:
# Models and functions
# The architecture is chosen once via `model_name` and reused by the
# fine-tuning and submission cells, which rebuild the model from the same key
# and load the checkpoint written here. Previously an EfficientNet-B0 was
# trained while `model_name` said "res18", so the checkpoint could not be
# loaded later.
model_name = "effb0"
get_model, in_features = models_[model_name]
model = get_model(fine_tune=False)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# float32 pos_weight matches the dtype of the logits.
bce_loss = torch.nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor(class_weights, dtype=torch.float32, device=device)
)

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 131MB/s]


In [16]:
epochs = 3
model.to(device)
best_score = 0

# Wrap only once so that re-running the cell does not nest DataParallel wrappers.
if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

for epoch in range(epochs):

    # Training
    model.train()
    train_loss = 0
    for keys, imgs, targets in tqdm(train_loader, desc="Training"):
        imgs, targets = imgs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        logits = model(imgs)
        loss = bce_loss(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Validation
    model.eval()
    val_loss = 0

    all_probs = []
    all_targets = []

    with torch.inference_mode():
        for keys, imgs, targets in tqdm(val_loader, desc="Validation"):
            imgs, targets = imgs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

            logits = model(imgs)
            loss = bce_loss(logits, targets)
            probs = torch.sigmoid(logits)

            val_loss += loss.item()

            all_probs.append(probs)
            all_targets.append(targets)

    val_loss /= len(val_loader)
    all_probs = torch.cat(all_probs, dim=0)
    all_targets = torch.cat(all_targets, dim=0)

    val_accuracy  = accuracy(all_probs, all_targets.int(), task="multilabel", num_labels=14)
    val_precision = precision(all_probs, all_targets.int(), task="multilabel", num_labels=14)
    val_recall    = recall(all_probs, all_targets.int(), task="multilabel", num_labels=14)
    val_auroc     = auroc(all_probs, all_targets.int(), task="multilabel", num_labels=14)

    if val_auroc > best_score:
        best_score = val_auroc
        # Save the unwrapped weights so they load into a plain (non-DataParallel) model.
        initial_best_model = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
        # The file name must match the one loaded by the fine-tuning cell.
        torch.save(initial_best_model, "initial_best_model.pth")
        print(f"New best auroc score: {best_score:.4f}")

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"Acc: {val_accuracy:.4f} | "
        f"Prec: {val_precision:.4f} | "
        f"Rec: {val_recall:.4f} | "
        f"AUROC: {val_auroc:.4f}"
    )

Training:   0%|          | 0/2685 [00:18<?, ?it/s]


KeyboardInterrupt: 

### Finetuning using AdaAugment and Focal Loss

In [None]:
# rand_m=True: magnitudes are random; rand_t=False: transforms come from the agent.
randm, randt = True, False

ada_augment = AdaAugment(rand_m=randm, rand_t=randt)
train_set = XRayDataset(train_df, train=True, transform=ada_augment)

# Go through NumPy: passing a Series to torch.tensor makes torch index it like
# a sequence, which is label-based and breaks on a non-default index.
w_ = torch.as_tensor(train_set.df["weight"].to_numpy(), dtype=torch.float32)
# Built for optional class-balanced sampling; currently not passed to the loader.
sampler = torch.utils.data.WeightedRandomSampler(
    weights=w_,
    num_samples=len(w_),
    replacement=True
)

# New loader using adaptive augmentations
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    #sampler=sampler,
    shuffle=True,
    pin_memory=True,
    num_workers=os.cpu_count()
)

In [None]:
get_model, in_features = models_[model_name]

model = get_model(fine_tune=True)

# The checkpoint is a plain state dict, so the restricted unpickler is enough.
model.load_state_dict(torch.load("initial_best_model.pth", map_location=device, weights_only=True))

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

# Pass the class weights as a tensor so the loss can move them to the target's device.
loss_fn = FocalLoss(weights=torch.tensor(class_weights, dtype=torch.float32), reduction="none")

# The agent consumes features produced on `device`, so it has to live there too.
agent = Agent(in_features, not randt, not randm).to(device)
# Filled by the forward hook with the classifier head's input.
state = {}
best_score = 0

In [None]:
def _register_head_hook(model):
    def hook(module, input, output):
        state['head_input'] = input[0].detach()
        
    _hook_handle = model.classifier[-1].register_forward_hook(hook)
    return _hook_handle
    
def _remove_head_hook(_hook_handle):
    _hook_handle.remove()

In [None]:
epochs = 5
r_weight = 1  # weight of the loss-change term; (1 - r_weight) weights the entropy term
model.to(device)
agent.to(device)  # the agent consumes features that live on `device`

# Wrap only once so that re-running the cell does not nest DataParallel wrappers.
if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

# The classifier head lives on the unwrapped module.
# NOTE(review): with multi-GPU DataParallel the hook presumably captures only
# one replica's chunk of the batch; confirm before training on several GPUs.
base_model = model.module if isinstance(model, nn.DataParallel) else model

for epoch in range(epochs):
    # Training

    model.train()
    handle = _register_head_hook(base_model)

    train_loss, total_actor_loss, total_critic_loss = 0, 0, 0
    for keys, imgs, targets in tqdm(train_loader, desc="Training"):
        imgs, targets = imgs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        logits = model(imgs)
        # Keep `state` as the dict the hook writes into; rebinding it to a
        # tensor broke the hook on the next forward pass.
        head_state = state["head_input"]

        # Focal loss per label -> per sample -> scalar. The loss memory needs
        # one entry per key, which a scalar loss cannot provide.
        per_label_loss = loss_fn(logits, targets)
        sample_loss = per_label_loss.mean(dim=1)
        loss = sample_loss.mean()
        prev_loss = agent.loss_memory(keys, sample_loss.detach())
        probs = torch.sigmoid(logits).detach()

        action, transform_idx = agent.action(head_state)
        if action is not None:
            ada_augment.set(keys, action, transform_idx)
        elif transform_idx is not None:
            # No magnitude was sampled (actor disabled): record only the
            # transform choices and leave the magnitudes to the random fallback.
            for sample_key, idx in zip(keys, transform_idx):
                ada_augment.key_transform[sample_key] = idx.detach().cpu()

        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1).unsqueeze(1)
        # `reward = r_weight = ...` used to overwrite r_weight with a tensor;
        # the weight is meant to scale the loss-change term.
        reward = r_weight * (loss.detach() - prev_loss.mean()) + (1 - r_weight) * entropy

        actor_loss, critic_loss = agent.update(keys, head_state, reward)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total_actor_loss += actor_loss.item()
        total_critic_loss += critic_loss.item()
    train_loss /= len(train_loader)
    total_actor_loss /= len(train_loader)
    total_critic_loss /= len(train_loader)

    _remove_head_hook(handle)
    # Validation
    model.eval()
    val_loss = 0

    all_probs = []
    all_targets = []

    with torch.inference_mode():
        for keys, imgs, targets in tqdm(val_loader, desc="Validation"):
            imgs, targets = imgs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

            logits = model(imgs)
            loss = bce_loss(logits, targets)
            probs = torch.sigmoid(logits)

            val_loss += loss.item()

            all_probs.append(probs)
            all_targets.append(targets)

    val_loss /= len(val_loader)
    all_probs = torch.cat(all_probs, dim=0)
    all_targets = torch.cat(all_targets, dim=0)

    val_accuracy  = accuracy(all_probs, all_targets.int(), task="multilabel", num_labels=14)
    val_precision = precision(all_probs, all_targets.int(), task="multilabel", num_labels=14)
    val_recall = recall(all_probs, all_targets.int(), task="multilabel", num_labels=14)
    val_auroc = auroc(all_probs, all_targets.int(), task="multilabel", num_labels=14)

    if val_auroc > best_score:
        best_score = val_auroc
        initial_best_model = base_model.state_dict()
        torch.save(initial_best_model, "best_model.pth")
        print(f"New best auroc score: {best_score:.4f}")

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Actor Loss: {total_actor_loss:.4f} |"
        f"Critic LOss: {total_critic_loss:.4f} |"
        f"Val Loss: {val_loss:.4f} | "
        f"Acc: {val_accuracy:.4f} | "
        f"Prec: {val_precision:.4f} | "
        f"Rec: {val_recall:.4f} | "
        f"AUROC: {val_auroc:.4f}"
    )

### Create submission using the best model

In [None]:
def create_submission(model_name=model_name):
    """Run the best checkpoint over the test loader.

    Args:
        model_name: Key into ``models_`` selecting the architecture that
            ``best_model.pth`` was trained with.

    Returns:
        ``(all_keys, all_probs)``: the image names and an array of sigmoid
        probabilities with one row per image.
    """
    get_model, _ = models_[model_name]
    model = get_model()
    # The checkpoint is a plain state dict, so the restricted unpickler is enough.
    model.load_state_dict(torch.load("best_model.pth", map_location=device, weights_only=True))
    # Inputs are moved to `device` below, so the model has to be there as well.
    model.to(device)
    model.eval()

    all_keys = []
    all_probs = []

    with torch.inference_mode():
        for key, img in tqdm(test_loader):
            img = img.to(device)

            out = model(img)
            probs = torch.sigmoid(out)

            all_keys.extend(key)
            all_probs.append(probs.cpu().numpy())

    # np.concatenate is available in every NumPy release; np.concat only exists in NumPy >= 2.0.
    all_probs = np.concatenate(all_probs, axis=0)

    return all_keys, all_probs

In [None]:
all_keys, all_probs = create_submission()

# Probabilities above this threshold are written as positive findings.
threshold = 0.5

# Binarize first, then build the frame with the image names as the first column.
final_submission = pd.DataFrame((all_probs > threshold).astype(int), columns=label_columns)
final_submission.insert(0, "Image_name", all_keys)
final_submission.to_csv('/kaggle/working/submission.csv', index=False)

print("="* 60)
print("Final submission file created at /kaggle/working/submission.csv")
print("="* 60)