In [None]:
import os
import numpy as np 
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve

import torch
from torch import nn, optim
from torchvision import models
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
class CheXpertDataset(Dataset):
    def __init__(self, data, root_dir, mode='train', transforms=None):
        self.data = data.to_numpy()
        self.labels = torch.tensor(data.values)
        self.root_dir = root_dir
        self.img_paths = [os.path.join(root_dir, img_path) for img_path in data.index]
        self.transform = transforms.get(mode)
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image=np.array(image))['image']

        return (image, label)
    
def get_weighted_random_sampler(data, batch_size):
    weights = 1/data.sum()
    sampler = WeightedRandomSampler(weights, batch_size, replacement=True)
    return sampler

transform = {
    'train': A.Compose([
        A.Affine(scale=(0.95, 1.05), p=0.5),
        A.OneOf([A.Affine(rotate=(-20, 20), p=0.5), A.Affine(shear=(-5, 5), p=0.5)], p=0.5),
        A.Affine(translate_percent=(-0.05, 0.05), p=0.5),
        A.Resize(224, 224),
        A.Normalize([0.506, 0.506, 0.506], [0.287, 0.287, 0.287]),
        ToTensorV2()
    ]),
    'val': A.Compose([
        A.Resize(224, 224),
        A.Normalize([0.506, 0.506, 0.506], [0.287, 0.287, 0.287]),
        ToTensorV2()
    ]),
}

In [101]:
test = pd.read_csv('/kaggle/input/chexpertsmallclean/u0_test.csv',index_col=0)
train = pd.read_csv('/kaggle/input/chexpertsmallclean/u0_train.csv', index_col=0)
val = pd.read_csv('/kaggle/input/chexpertsmallclean/u0_val.csv', index_col=0)

In [102]:
train.index =  train.index.str.replace('CheXpert-v1.0-small', 'chexpert')
test.index = test.index.str.replace('CheXpert-v1.0-small', 'chexpert')
val.index = val.index.str.replace('CheXpert-v1.0-small', 'chexpert')

In [103]:
batch_size = 32
train_dataset = CheXpertDataset(
    data=train,
    root_dir='/kaggle/input/', 
    mode='train',
    transforms = transform
    )
train_loader = DataLoader(
    train_dataset,
    batch_size = batch_size,
    sampler = get_weighted_random_sampler(train, batch_size),
    num_workers = 4,
    pin_memory=True
    )

  weights_tensor = torch.as_tensor(weights, dtype=torch.double)


In [104]:
pretrained_densenet = models.densenet121(pretrained=True)

for param in pretrained_densenet.features.parameters():
    param.requires_grad = False

# Define a new classifier with additional layers
class CustomDenseNet(nn.Module):
    def __init__(self, pretrained_model, num_classes):
        super(CustomDenseNet, self).__init__()
        
        # Retain the feature extraction part of the pre-trained DenseNet
        self.features = pretrained_model.features
        
        # Add custom convolutional layers after the DenseNet feature extractor
        self.additional_conv = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, padding=1),  # Example additional convolutional layer
            nn.ReLU(),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU()
        )
        
        # Global average pooling (to reduce spatial dimensions before the classifier)
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Final fully connected layer (classifier), adjusted for multi-label classification
        self.classifier = nn.Linear(256, num_classes)  # 128 is the output channels from the conv layers
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # Pass through the DenseNet feature extractor
        x = self.features(x)
        
        # Pass through the additional convolutional layers
        x = self.additional_conv(x)
        
        # Apply global average pooling
        x = self.pool(x)
        
        # Flatten the tensor and pass it through the classifier
        x = torch.flatten(x, 1)  # Flatten the output from the pooling layer
        x = self.classifier(x)
        return x

# Replace the DenseNet's classifier with the custom classifier
num_classes = 14
model = CustomDenseNet(pretrained_densenet, num_classes)



In [105]:
model = models.densenet121(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
model.classifier = nn.Linear(in_features=model.classifier.in_features, out_features=14)

In [106]:
loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9,0.999), weight_decay=1e-4)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                    factor = 0.1, patience = 5, mode = 'max', verbose=True)



In [107]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

num_epochs = 200

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
# Wrap the training data loader with tqdm for a progress bar
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
# Forward pass: Get the logits (raw scores) from the model
        outputs = model(inputs)
# Calculate the loss: BCEWithLogitsLoss expects logits, not probabilities
        loss = loss_function(outputs, labels)
# Backward pass and optimization step
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
# Print loss after each epoch
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

Epoch 1/200: 100%|██████████| 1/1 [00:00<00:00,  2.20batch/s]


Epoch 1/200, Loss: 0.7634506974635379


Epoch 2/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]


Epoch 2/200, Loss: 0.7521307688075467


Epoch 3/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 3/200, Loss: 0.7236443686010066


Epoch 4/200: 100%|██████████| 1/1 [00:00<00:00,  2.28batch/s]


Epoch 4/200, Loss: 0.7096581289442838


Epoch 5/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 5/200, Loss: 0.6916129885108343


Epoch 6/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]


Epoch 6/200, Loss: 0.6898247535843568


Epoch 7/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 7/200, Loss: 0.6799509549954174


Epoch 8/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 8/200, Loss: 0.663703754745906


Epoch 9/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 9/200, Loss: 0.6428269501255791


Epoch 10/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 10/200, Loss: 0.6228507972856667


Epoch 11/200: 100%|██████████| 1/1 [00:00<00:00,  2.37batch/s]


Epoch 11/200, Loss: 0.623840610409908


Epoch 12/200: 100%|██████████| 1/1 [00:00<00:00,  2.27batch/s]


Epoch 12/200, Loss: 0.6119996003856483


Epoch 13/200: 100%|██████████| 1/1 [00:00<00:00,  2.18batch/s]


Epoch 13/200, Loss: 0.5853431463786234


Epoch 14/200: 100%|██████████| 1/1 [00:00<00:00,  2.20batch/s]


Epoch 14/200, Loss: 0.5835973619240186


Epoch 15/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 15/200, Loss: 0.569165155416288


Epoch 16/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 16/200, Loss: 0.5524335000809515


Epoch 17/200: 100%|██████████| 1/1 [00:00<00:00,  2.28batch/s]


Epoch 17/200, Loss: 0.5458217977803932


Epoch 18/200: 100%|██████████| 1/1 [00:00<00:00,  2.27batch/s]


Epoch 18/200, Loss: 0.5456957468663209


Epoch 19/200: 100%|██████████| 1/1 [00:00<00:00,  2.33batch/s]


Epoch 19/200, Loss: 0.5321343794722841


Epoch 20/200: 100%|██████████| 1/1 [00:00<00:00,  2.20batch/s]


Epoch 20/200, Loss: 0.5258687351140127


Epoch 21/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 21/200, Loss: 0.5173384170422131


Epoch 22/200: 100%|██████████| 1/1 [00:00<00:00,  2.29batch/s]


Epoch 22/200, Loss: 0.49467154826353565


Epoch 23/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 23/200, Loss: 0.4909330703854461


Epoch 24/200: 100%|██████████| 1/1 [00:00<00:00,  2.29batch/s]


Epoch 24/200, Loss: 0.4838561508981261


Epoch 25/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 25/200, Loss: 0.4881284194083751


Epoch 26/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 26/200, Loss: 0.4807858919292422


Epoch 27/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 27/200, Loss: 0.473500394000439


Epoch 28/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 28/200, Loss: 0.4590703301905055


Epoch 29/200: 100%|██████████| 1/1 [00:00<00:00,  2.32batch/s]


Epoch 29/200, Loss: 0.4588334993209823


Epoch 30/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 30/200, Loss: 0.4471905061343152


Epoch 31/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 31/200, Loss: 0.45262142945596545


Epoch 32/200: 100%|██████████| 1/1 [00:00<00:00,  2.23batch/s]


Epoch 32/200, Loss: 0.44240483616886195


Epoch 33/200: 100%|██████████| 1/1 [00:00<00:00,  2.28batch/s]


Epoch 33/200, Loss: 0.44067504934043555


Epoch 34/200: 100%|██████████| 1/1 [00:00<00:00,  2.32batch/s]


Epoch 34/200, Loss: 0.4289263678282233


Epoch 35/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 35/200, Loss: 0.4159046768348032


Epoch 36/200: 100%|██████████| 1/1 [00:00<00:00,  2.28batch/s]


Epoch 36/200, Loss: 0.41978784962806714


Epoch 37/200: 100%|██████████| 1/1 [00:00<00:00,  2.20batch/s]


Epoch 37/200, Loss: 0.41026522617072


Epoch 38/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 38/200, Loss: 0.4193283374791333


Epoch 39/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 39/200, Loss: 0.40531744702561157


Epoch 40/200: 100%|██████████| 1/1 [00:00<00:00,  2.21batch/s]


Epoch 40/200, Loss: 0.3973387530173308


Epoch 41/200: 100%|██████████| 1/1 [00:00<00:00,  2.26batch/s]


Epoch 41/200, Loss: 0.39091729621369653


Epoch 42/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 42/200, Loss: 0.3911705668932492


Epoch 43/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]


Epoch 43/200, Loss: 0.3905557704690311


Epoch 44/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 44/200, Loss: 0.3813494179630652


Epoch 45/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 45/200, Loss: 0.3631508545965646


Epoch 46/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 46/200, Loss: 0.3806333666746338


Epoch 47/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 47/200, Loss: 0.3800947249567668


Epoch 48/200: 100%|██████████| 1/1 [00:00<00:00,  2.32batch/s]


Epoch 48/200, Loss: 0.36852411748675096


Epoch 49/200: 100%|██████████| 1/1 [00:00<00:00,  2.24batch/s]


Epoch 49/200, Loss: 0.3775127913021216


Epoch 50/200: 100%|██████████| 1/1 [00:00<00:00,  2.23batch/s]


Epoch 50/200, Loss: 0.37607134467439857


Epoch 51/200: 100%|██████████| 1/1 [00:00<00:00,  2.17batch/s]


Epoch 51/200, Loss: 0.3690161001328046


Epoch 52/200: 100%|██████████| 1/1 [00:00<00:00,  2.07batch/s]


Epoch 52/200, Loss: 0.37634925461939667


Epoch 53/200: 100%|██████████| 1/1 [00:00<00:00,  2.13batch/s]


Epoch 53/200, Loss: 0.36023952525075786


Epoch 54/200: 100%|██████████| 1/1 [00:00<00:00,  2.32batch/s]


Epoch 54/200, Loss: 0.3555350973287464


Epoch 55/200: 100%|██████████| 1/1 [00:00<00:00,  2.32batch/s]


Epoch 55/200, Loss: 0.35585950132683913


Epoch 56/200: 100%|██████████| 1/1 [00:00<00:00,  2.22batch/s]


Epoch 56/200, Loss: 0.3523108626416485


Epoch 57/200: 100%|██████████| 1/1 [00:00<00:00,  2.33batch/s]


Epoch 57/200, Loss: 0.34254216332088355


Epoch 58/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]


Epoch 58/200, Loss: 0.3467219790659978


Epoch 59/200: 100%|██████████| 1/1 [00:00<00:00,  2.32batch/s]


Epoch 59/200, Loss: 0.3561273782979697


Epoch 60/200: 100%|██████████| 1/1 [00:00<00:00,  2.19batch/s]


Epoch 60/200, Loss: 0.34764942925955566


Epoch 61/200: 100%|██████████| 1/1 [00:00<00:00,  2.33batch/s]


Epoch 61/200, Loss: 0.323338202710147


Epoch 62/200: 100%|██████████| 1/1 [00:00<00:00,  2.27batch/s]


Epoch 62/200, Loss: 0.3378201234528595


Epoch 63/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 63/200, Loss: 0.34115333050223334


Epoch 64/200: 100%|██████████| 1/1 [00:00<00:00,  2.22batch/s]


Epoch 64/200, Loss: 0.3130598224992614


Epoch 65/200: 100%|██████████| 1/1 [00:00<00:00,  2.32batch/s]


Epoch 65/200, Loss: 0.33715286010660095


Epoch 66/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 66/200, Loss: 0.3298080482865251


Epoch 67/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 67/200, Loss: 0.32716901434053264


Epoch 68/200: 100%|██████████| 1/1 [00:00<00:00,  2.22batch/s]


Epoch 68/200, Loss: 0.32402868026318693


Epoch 69/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 69/200, Loss: 0.330191009014795


Epoch 70/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 70/200, Loss: 0.3316692516631779


Epoch 71/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 71/200, Loss: 0.3305051588270414


Epoch 72/200: 100%|██████████| 1/1 [00:00<00:00,  2.22batch/s]


Epoch 72/200, Loss: 0.3161519849839221


Epoch 73/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 73/200, Loss: 0.3340665170773198


Epoch 74/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 74/200, Loss: 0.33184423788147144


Epoch 75/200: 100%|██████████| 1/1 [00:00<00:00,  2.33batch/s]


Epoch 75/200, Loss: 0.3178455898471709


Epoch 76/200: 100%|██████████| 1/1 [00:00<00:00,  2.22batch/s]


Epoch 76/200, Loss: 0.31588892140799935


Epoch 77/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 77/200, Loss: 0.31147857567495


Epoch 78/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]


Epoch 78/200, Loss: 0.30971822890569456


Epoch 79/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 79/200, Loss: 0.3330648137905103


Epoch 80/200: 100%|██████████| 1/1 [00:00<00:00,  2.23batch/s]


Epoch 80/200, Loss: 0.3254487578732161


Epoch 81/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 81/200, Loss: 0.32900165295827066


Epoch 82/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 82/200, Loss: 0.30541867165343967


Epoch 83/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 83/200, Loss: 0.3073282706027385


Epoch 84/200: 100%|██████████| 1/1 [00:00<00:00,  2.24batch/s]


Epoch 84/200, Loss: 0.29808778651723905


Epoch 85/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 85/200, Loss: 0.30503453153401744


Epoch 86/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 86/200, Loss: 0.31965666666759973


Epoch 87/200: 100%|██████████| 1/1 [00:00<00:00,  2.28batch/s]


Epoch 87/200, Loss: 0.30957039383273305


Epoch 88/200: 100%|██████████| 1/1 [00:00<00:00,  2.22batch/s]


Epoch 88/200, Loss: 0.31384222031920217


Epoch 89/200: 100%|██████████| 1/1 [00:00<00:00,  2.21batch/s]


Epoch 89/200, Loss: 0.30537077412425007


Epoch 90/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 90/200, Loss: 0.2996755488150354


Epoch 91/200: 100%|██████████| 1/1 [00:00<00:00,  2.33batch/s]


Epoch 91/200, Loss: 0.2982692801825968


Epoch 92/200: 100%|██████████| 1/1 [00:00<00:00,  2.24batch/s]


Epoch 92/200, Loss: 0.28194907483079334


Epoch 93/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 93/200, Loss: 0.27920661351111314


Epoch 94/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 94/200, Loss: 0.28335757275843726


Epoch 95/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 95/200, Loss: 0.2936495325619554


Epoch 96/200: 100%|██████████| 1/1 [00:00<00:00,  2.23batch/s]


Epoch 96/200, Loss: 0.29392521758563817


Epoch 97/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 97/200, Loss: 0.28615849766980056


Epoch 98/200: 100%|██████████| 1/1 [00:00<00:00,  2.40batch/s]


Epoch 98/200, Loss: 0.29066784107791527


Epoch 99/200: 100%|██████████| 1/1 [00:00<00:00,  2.33batch/s]


Epoch 99/200, Loss: 0.28455409062943154


Epoch 100/200: 100%|██████████| 1/1 [00:00<00:00,  2.28batch/s]


Epoch 100/200, Loss: 0.29051801493291607


Epoch 101/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]


Epoch 101/200, Loss: 0.32703454666105763


Epoch 102/200: 100%|██████████| 1/1 [00:00<00:00,  2.33batch/s]


Epoch 102/200, Loss: 0.28930828299988726


Epoch 103/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 103/200, Loss: 0.29924950025451835


Epoch 104/200: 100%|██████████| 1/1 [00:00<00:00,  2.22batch/s]


Epoch 104/200, Loss: 0.28196219650895465


Epoch 105/200: 100%|██████████| 1/1 [00:00<00:00,  2.25batch/s]


Epoch 105/200, Loss: 0.29666474720085845


Epoch 106/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 106/200, Loss: 0.33060446242078406


Epoch 107/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 107/200, Loss: 0.2668251295773578


Epoch 108/200: 100%|██████████| 1/1 [00:00<00:00,  2.27batch/s]


Epoch 108/200, Loss: 0.3242369962099474


Epoch 109/200: 100%|██████████| 1/1 [00:00<00:00,  2.41batch/s]


Epoch 109/200, Loss: 0.3447330784630529


Epoch 110/200: 100%|██████████| 1/1 [00:00<00:00,  2.40batch/s]


Epoch 110/200, Loss: 0.274624711675902


Epoch 111/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]


Epoch 111/200, Loss: 0.2912338209992283


Epoch 112/200: 100%|██████████| 1/1 [00:00<00:00,  2.24batch/s]


Epoch 112/200, Loss: 0.2898394584003004


Epoch 113/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]


Epoch 113/200, Loss: 0.2697048642606075


Epoch 114/200: 100%|██████████| 1/1 [00:00<00:00,  2.41batch/s]


Epoch 114/200, Loss: 0.28718038954996566


Epoch 115/200: 100%|██████████| 1/1 [00:00<00:00,  2.26batch/s]


Epoch 115/200, Loss: 0.27687541869818233


Epoch 116/200: 100%|██████████| 1/1 [00:00<00:00,  2.24batch/s]


Epoch 116/200, Loss: 0.2908467692738798


Epoch 117/200: 100%|██████████| 1/1 [00:00<00:00,  2.37batch/s]


Epoch 117/200, Loss: 0.28196175245102495


Epoch 118/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 118/200, Loss: 0.2824436120489346


Epoch 119/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 119/200, Loss: 0.27160927535234286


Epoch 120/200: 100%|██████████| 1/1 [00:00<00:00,  2.27batch/s]


Epoch 120/200, Loss: 0.25929548357505283


Epoch 121/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 121/200, Loss: 0.27652699400953545


Epoch 122/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 122/200, Loss: 0.2670192989919867


Epoch 123/200: 100%|██████████| 1/1 [00:00<00:00,  2.08batch/s]


Epoch 123/200, Loss: 0.2462742446971658


Epoch 124/200: 100%|██████████| 1/1 [00:00<00:00,  1.99batch/s]


Epoch 124/200, Loss: 0.2719132738337586


Epoch 125/200: 100%|██████████| 1/1 [00:00<00:00,  2.20batch/s]


Epoch 125/200, Loss: 0.2812457328130092


Epoch 126/200: 100%|██████████| 1/1 [00:00<00:00,  2.29batch/s]


Epoch 126/200, Loss: 0.2726858685990529


Epoch 127/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 127/200, Loss: 0.26235618609644007


Epoch 128/200: 100%|██████████| 1/1 [00:00<00:00,  2.17batch/s]


Epoch 128/200, Loss: 0.2808782108726778


Epoch 129/200: 100%|██████████| 1/1 [00:00<00:00,  2.37batch/s]


Epoch 129/200, Loss: 0.28533889045190464


Epoch 130/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 130/200, Loss: 0.26680029783996617


Epoch 131/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 131/200, Loss: 0.2669154507717134


Epoch 132/200: 100%|██████████| 1/1 [00:00<00:00,  2.25batch/s]


Epoch 132/200, Loss: 0.272154176433105


Epoch 133/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 133/200, Loss: 0.266944473625959


Epoch 134/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 134/200, Loss: 0.26904085385779447


Epoch 135/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 135/200, Loss: 0.2550867799602981


Epoch 136/200: 100%|██████████| 1/1 [00:00<00:00,  2.27batch/s]


Epoch 136/200, Loss: 0.24837212647045295


Epoch 137/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 137/200, Loss: 0.2673533876077272


Epoch 138/200: 100%|██████████| 1/1 [00:00<00:00,  2.37batch/s]


Epoch 138/200, Loss: 0.28285378441588754


Epoch 139/200: 100%|██████████| 1/1 [00:00<00:00,  2.29batch/s]


Epoch 139/200, Loss: 0.26890618403974387


Epoch 140/200: 100%|██████████| 1/1 [00:00<00:00,  2.26batch/s]


Epoch 140/200, Loss: 0.23881655573079894


Epoch 141/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 141/200, Loss: 0.2554199706896075


Epoch 142/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 142/200, Loss: 0.26913777264832917


Epoch 143/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 143/200, Loss: 0.25351160565332975


Epoch 144/200: 100%|██████████| 1/1 [00:00<00:00,  2.26batch/s]


Epoch 144/200, Loss: 0.2641628338522943


Epoch 145/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 145/200, Loss: 0.27274436393885737


Epoch 146/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 146/200, Loss: 0.2546969681944964


Epoch 147/200: 100%|██████████| 1/1 [00:00<00:00,  2.25batch/s]


Epoch 147/200, Loss: 0.23486282383549092


Epoch 148/200: 100%|██████████| 1/1 [00:00<00:00,  2.24batch/s]


Epoch 148/200, Loss: 0.23800762293727268


Epoch 149/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 149/200, Loss: 0.2481584871254329


Epoch 150/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 150/200, Loss: 0.23139934095421005


Epoch 151/200: 100%|██████████| 1/1 [00:00<00:00,  2.22batch/s]


Epoch 151/200, Loss: 0.23796970245478274


Epoch 152/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 152/200, Loss: 0.23788653978928256


Epoch 153/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 153/200, Loss: 0.24599510145240594


Epoch 154/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 154/200, Loss: 0.25587788052091903


Epoch 155/200: 100%|██████████| 1/1 [00:00<00:00,  2.29batch/s]


Epoch 155/200, Loss: 0.2689766247889825


Epoch 156/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 156/200, Loss: 0.2580175388928702


Epoch 157/200: 100%|██████████| 1/1 [00:00<00:00,  2.33batch/s]


Epoch 157/200, Loss: 0.25749671543599106


Epoch 158/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 158/200, Loss: 0.24047192672982678


Epoch 159/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 159/200, Loss: 0.22888386052467727


Epoch 160/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 160/200, Loss: 0.24891812914782868


Epoch 161/200: 100%|██████████| 1/1 [00:00<00:00,  2.37batch/s]


Epoch 161/200, Loss: 0.2512478377736573


Epoch 162/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 162/200, Loss: 0.23601971102470998


Epoch 163/200: 100%|██████████| 1/1 [00:00<00:00,  2.28batch/s]


Epoch 163/200, Loss: 0.2382282388113838


Epoch 164/200: 100%|██████████| 1/1 [00:00<00:00,  2.09batch/s]


Epoch 164/200, Loss: 0.22437252703821287


Epoch 165/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 165/200, Loss: 0.2426154320585608


Epoch 166/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 166/200, Loss: 0.2262584624030361


Epoch 167/200: 100%|██████████| 1/1 [00:00<00:00,  2.35batch/s]


Epoch 167/200, Loss: 0.2442097527590314


Epoch 168/200: 100%|██████████| 1/1 [00:00<00:00,  2.22batch/s]


Epoch 168/200, Loss: 0.24343540890965543


Epoch 169/200: 100%|██████████| 1/1 [00:00<00:00,  2.39batch/s]


Epoch 169/200, Loss: 0.2347068192383241


Epoch 170/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 170/200, Loss: 0.236167488769362


Epoch 171/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 171/200, Loss: 0.23442044800945688


Epoch 172/200: 100%|██████████| 1/1 [00:00<00:00,  2.26batch/s]


Epoch 172/200, Loss: 0.22497508160969507


Epoch 173/200: 100%|██████████| 1/1 [00:00<00:00,  2.30batch/s]


Epoch 173/200, Loss: 0.2243352664435016


Epoch 174/200: 100%|██████████| 1/1 [00:00<00:00,  2.26batch/s]


Epoch 174/200, Loss: 0.2237959615553596


Epoch 175/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 175/200, Loss: 0.24486171575179988


Epoch 176/200: 100%|██████████| 1/1 [00:00<00:00,  2.20batch/s]


Epoch 176/200, Loss: 0.2349195810300963


Epoch 177/200: 100%|██████████| 1/1 [00:00<00:00,  2.27batch/s]


Epoch 177/200, Loss: 0.23257313045074363


Epoch 178/200: 100%|██████████| 1/1 [00:00<00:00,  2.32batch/s]


Epoch 178/200, Loss: 0.24080777231470812


Epoch 179/200: 100%|██████████| 1/1 [00:00<00:00,  2.32batch/s]


Epoch 179/200, Loss: 0.24081791931738344


Epoch 180/200: 100%|██████████| 1/1 [00:00<00:00,  2.28batch/s]


Epoch 180/200, Loss: 0.24926546852969164


Epoch 181/200: 100%|██████████| 1/1 [00:00<00:00,  2.33batch/s]


Epoch 181/200, Loss: 0.24177768166243496


Epoch 182/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 182/200, Loss: 0.24054432076601576


Epoch 183/200: 100%|██████████| 1/1 [00:00<00:00,  2.23batch/s]


Epoch 183/200, Loss: 0.22986362822952547


Epoch 184/200: 100%|██████████| 1/1 [00:00<00:00,  2.29batch/s]


Epoch 184/200, Loss: 0.22745733146023536


Epoch 185/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 185/200, Loss: 0.23552140334088886


Epoch 186/200: 100%|██████████| 1/1 [00:00<00:00,  2.31batch/s]


Epoch 186/200, Loss: 0.24546284090320114


Epoch 187/200: 100%|██████████| 1/1 [00:00<00:00,  2.05batch/s]


Epoch 187/200, Loss: 0.20836812720309741


Epoch 188/200: 100%|██████████| 1/1 [00:00<00:00,  2.21batch/s]


Epoch 188/200, Loss: 0.22540419823989005


Epoch 189/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 189/200, Loss: 0.20514113076829482


Epoch 190/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]


Epoch 190/200, Loss: 0.2212697111702125


Epoch 191/200: 100%|██████████| 1/1 [00:00<00:00,  2.27batch/s]


Epoch 191/200, Loss: 0.2217432830594979


Epoch 192/200: 100%|██████████| 1/1 [00:00<00:00,  2.29batch/s]


Epoch 192/200, Loss: 0.22868025475846868


Epoch 193/200: 100%|██████████| 1/1 [00:00<00:00,  2.36batch/s]


Epoch 193/200, Loss: 0.24122046241037812


Epoch 194/200: 100%|██████████| 1/1 [00:00<00:00,  2.34batch/s]


Epoch 194/200, Loss: 0.2226783440980528


Epoch 195/200: 100%|██████████| 1/1 [00:00<00:00,  1.98batch/s]


Epoch 195/200, Loss: 0.21286412787490655


Epoch 196/200: 100%|██████████| 1/1 [00:00<00:00,  2.08batch/s]


Epoch 196/200, Loss: 0.20958035093333038


Epoch 197/200: 100%|██████████| 1/1 [00:00<00:00,  2.13batch/s]


Epoch 197/200, Loss: 0.21119573624205906


Epoch 198/200: 100%|██████████| 1/1 [00:00<00:00,  2.24batch/s]


Epoch 198/200, Loss: 0.21428752814452828


Epoch 199/200: 100%|██████████| 1/1 [00:00<00:00,  2.23batch/s]


Epoch 199/200, Loss: 0.205560519626098


Epoch 200/200: 100%|██████████| 1/1 [00:00<00:00,  2.38batch/s]

Epoch 200/200, Loss: 0.2028494586369821





In [108]:
test_dataset = CheXpertDataset(test, '/kaggle/input/', 
                                mode='val', transforms = transform)
test_loader = DataLoader(
    test_dataset,
    batch_size = 32,
    num_workers = 4,
    pin_memory=True)

In [109]:
model.eval()

all_labels = []  # To store all the true labels
all_preds = []  # To store all the predicted probabilities

# Evaluate the model on the validation set
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Testing", unit="batch"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Get the logits from the model
        outputs = model(inputs)

        # Apply sigmoid to the outputs to get probabilities
        outputs = torch.sigmoid(outputs)

        # Collect the true labels and predictions
        all_labels.append(labels.cpu().numpy())  # Convert to numpy and store
        all_preds.append(outputs.cpu().numpy())  # Convert to numpy and store

# Convert the list of arrays into a single numpy array for true labels and predictions
all_labels = np.concatenate(all_labels, axis=0)
all_preds = np.concatenate(all_preds, axis=0)

# Compute AUROC for multi-label classification
auroc_scores = []
for i in range(all_labels.shape[1]):  # Loop through each label
    auroc = roc_auc_score(all_labels[:, i], all_preds[:, i])  # Calculate AUROC for each label
    auroc_scores.append(auroc)

# Calculate the mean AUROC score across all labels
mean_auroc = np.mean(auroc_scores)
print(f'Mean AUROC: {mean_auroc:.4f}')

Testing: 100%|██████████| 598/598 [00:37<00:00, 15.96batch/s]


Mean AUROC: 0.5244


In [110]:
# Function to calculate and plot AUROC for multi-label classification
def plot_auroc(model, dataloader, num_classes, device):
    model.eval()
    
    true_labels = []
    predicted_probs = []
    
    # Collect true labels and predicted probabilities
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Get raw model output (logits)
            outputs = model(inputs)
            
            # Apply sigmoid to get probabilities for each label
            probs = torch.sigmoid(outputs)
            
            true_labels.append(labels.cpu().numpy())
            predicted_probs.append(probs.cpu().numpy())
    
    # Convert to numpy arrays
    true_labels = np.vstack(true_labels)
    predicted_probs = np.vstack(predicted_probs)

    # Calculate AUROC for each label
    fpr = {}
    tpr = {}
    roc_auc = {}

    # Iterate over each label and compute the ROC curve and AUROC
    for i in range(num_classes):
        fpr[i], tpr[i], _ = roc_curve(true_labels[:, i], predicted_probs[:, i])
        roc_auc[i] = roc_auc_score(true_labels[:, i], predicted_probs[:, i])

    # Plot the ROC curve for each label
    plt.figure(figsize=(10, 8))
    for i in range(num_classes):
        plt.plot(fpr[i], tpr[i], label=f'Class {i+1} (AUROC = {roc_auc[i]:.2f})')
    
    # Plot the diagonal (random guess line)
    plt.plot([0, 1], [0, 1], 'k--', label='Random guess (AUROC = 0.5)')
    
    # Add labels and legend
    plt.title('Multi-label (ROC) Curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc='lower right')
    plt.show()
    
    # Return mean AUROC
    mean_auroc = np.mean(list(roc_auc.values()))
    print(f'Mean AUROC across all classes: {mean_auroc:.2f}')
    return mean_auroc

model.to(device)
plot_auroc(model, test_loader, num_classes=14, device='cuda')

KeyboardInterrupt: 