In [18]:
#pip install torch torchvision numpy pandas opencv-python scikit-learn

In [19]:
import os
import cv2
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time as t
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision import transforms

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

In [20]:
import warnings
warnings.filterwarnings("ignore")

### Setting and parameters

In [21]:
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
NUM_CLASSES = 14
NUM_EPOCHS = 20
LR = 1e-5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_DIR_TRAIN = '/kaggle/input/grand-xray-slam-division-b/train2/'

### Load train data

In [22]:
try:
    train_df = pd.read_csv('/kaggle/input/grand-xray-slam-division-b/train2.csv')
    print(f"Loaded train2.csv with {len(train_df)} rows")
except FileNotFoundError:
    print("Error: train2.csv not found. Ensure dataset is attached.")
    raise

label_columns = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
    'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion',
    'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices'
]

missing_cols = [col for col in label_columns if col not in train_df.columns]
if missing_cols:
    raise KeyError(f"Missing columns: {missing_cols}")

subset_frac = 0.06
df_small = train_df.sample(frac=subset_frac, random_state=42)

# split into train and validation (80/20)
train_data, val_data = train_test_split(
    df_small, test_size=0.2, random_state=42
)
print(f"Train samples: {len(train_data)}, Validation samples: {len(val_data)}")

Loaded train2.csv with 108494 rows
Train samples: 5208, Validation samples: 1302


### Augmentations

In [23]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    A.LongestMaxSize(max_size=IMG_SIZE[0]),   
    A.PadIfNeeded(min_height=IMG_SIZE[0], min_width=IMG_SIZE[1], border_mode=0),  
    
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.02, scale_limit=0.05, rotate_limit=5, 
        border_mode=0, p=0.3
    ),
    A.RandomBrightnessContrast(
        brightness_limit=0.05, contrast_limit=0.05, p=0.3
    ),
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8,8), p=0.2),
    
    A.Normalize(mean=(0.5,), std=(0.5,)),  
    ToTensorV2()
])

val_transform = A.Compose([
    A.LongestMaxSize(max_size=IMG_SIZE[0]),
    A.PadIfNeeded(min_height=IMG_SIZE[0], min_width=IMG_SIZE[1], border_mode=0),
    A.Normalize(mean=(0.5,), std=(0.5,)),
    ToTensorV2()
])

### Processing images

In [24]:
class ChestXRayDataset(Dataset):
    def __init__(self, df: pd.DataFrame, image_dir: str, img_size=IMG_SIZE, 
                             is_test=False, label_cols=None, transform=None):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.img_size = img_size
        self.is_test = is_test
        self.label_cols = label_cols
        self.transform = transform

        if not os.path.exists(self.image_dir):
            raise FileNotFoundError(f"Image directory {self.image_dir} not found.")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = row['Image_name']
        img_path = os.path.join(self.image_dir, img_name)
    
        # load grayscale image
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None or img.size == 0:
            img = np.zeros(self.img_size, dtype=np.uint8)
        else:
            img = cv2.resize(img, self.img_size)
    
        # make (H, W, 1) for albumentations
        img = np.expand_dims(img, axis=-1)
    
        # augmentation
        if self.transform:
            # albumentations expects dict
            img = self.transform(image=img)["image"]
        else:
            # fallback: to tensor
            img = transforms.ToTensor()(img)
    
        if self.is_test:
            return img
        else:
            labels = torch.tensor(row[self.label_cols].values.astype(np.float32))
            return img, labels

### Wrapper

In [25]:
class CNNWrapper(nn.Module):
    def __init__(self, model_name, num_classes=NUM_CLASSES):
        super().__init__()
        self.backbone, self.feature_dim = self._get_backbone(model_name)
        self.fc = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        features = self.backbone(x)
        return self.fc(features)


    def _get_backbone(self, model_name, pretrained=True):
        if model_name == "resnet18":
            model = models.resnet18(pretrained=pretrained)
            model.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False)
            feature_dim = model.fc.in_features
            model.fc = nn.Identity()

        elif model_name == "resnext50":
            model = models.resnext50_32x4d(pretrained=pretrained)
            model.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False)
            feature_dim = model.fc.in_features
            model.fc = nn.Identity()

        elif model_name == "efficientnet_b0":
            model = models.efficientnet_b0(pretrained=pretrained)
            model.features[0][0] = nn.Conv2d(1, 32, 3, 2, 1, bias=False)
            feature_dim = model.classifier[1].in_features
            model.classifier = nn.Identity()

        elif model_name == "densenet121":
            model = models.densenet121(pretrained=pretrained)
            model.features.conv0 = nn.Conv2d(1, 64, 7, 2, 3, bias=False)
            feature_dim = model.classifier.in_features
            model.classifier = nn.Identity()

        elif model_name == "simplecnn":
            model = SimpleCNN()
            feature_dim = model.fc_out_layer.in_features
            model.fc_out = nn.Identity()

        else:
            raise ValueError(f"Unknown model: {model_name}")

        return model, feature_dim


### Baseline CNN model

In [26]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.flattened_size = 128 * 26 * 26
        self.fc = nn.Linear(self.flattened_size, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc_out_layer = nn.Linear(128, num_classes)
        self.fc_out = nn.Identity()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))
        x = self.dropout(x)
        return x  # feature vector


### Metrics

In [27]:
class MetricCalculator:
    def __init__(self, class_names=label_columns):
        self.class_names = class_names
    
    def compute_auc(self, y_true, y_pred):
        """Вычисляет средний AUC и AUC по каждому классу."""
        aucs = []
        for i in range(y_true.shape[1]):
            try:
                auc = roc_auc_score(y_true[:, i], y_pred[:, i])
            except ValueError:
                auc = np.nan  # если для класса нет положительных/отрицательных примеров
            aucs.append(auc)
        aucs = np.array(aucs, dtype=np.float32)
        return np.nanmean(aucs), aucs
    
    def print_auc(self, auc_mean, auc_per_class, phase="Train"):
        print(f"\n{phase} AUC: {auc_mean:.4f}")


### Plots

In [28]:
def plot_learning_curves(model_name, history, num_epochs=NUM_EPOCHS):
    epochs = np.arange(1, num_epochs+1)
    os.makedirs("plots", exist_ok=True)
    
    plt.figure(figsize=(10,4))
    # loss
    plt.subplot(1,2,1)
    plt.plot(epochs, history['train']['loss'], label='train loss')
    plt.plot(epochs, history['val']['loss'], label='val loss')
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend()
    plt.title(f'Loss {model_name}')
    # auc
    plt.subplot(1,2,2)
    plt.plot(epochs, history['train']['auc'], label='train AUC')
    plt.plot(epochs, history['val']['auc'], label='val AUC')
    plt.xlabel('Epoch'); plt.ylabel('AUC'); plt.legend()
    plt.title(f'AUC {model_name}')

    plt.tight_layout()

    save_path = os.path.join("plots", f"{model_name}_learning_curves_{num_epochs}_epochs.png")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved learning curves to: {save_path}")

In [29]:
def bar_aucs(aucs_per_classes, model_name="model"):
    os.makedirs("plots", exist_ok=True)

    plt.figure(figsize=(30, 4))
    plt.bar(label_columns, aucs_per_classes)
    plt.xlabel('Class')
    plt.ylabel('AUC')
    plt.title(f'AUCs by Classes ({model_name})')

    save_path = os.path.join("plots", f"{model_name}_aucs_per_class.png")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved per-class AUC bar chart to: {save_path}")

### Loaders

In [30]:
train_dataset = ChestXRayDataset(train_data, IMAGE_DIR_TRAIN, img_size=IMG_SIZE, 
                                 is_test=False, label_cols=label_columns, transform=train_transform)

val_dataset = ChestXRayDataset(val_data, IMAGE_DIR_TRAIN, img_size=IMG_SIZE, 
                               is_test=False, label_cols=label_columns, transform=val_transform)


train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=3, pin_memory=True)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=3, pin_memory=True)

### Save results

In [31]:
def save_results_csv(model_name, history, num_epochs, lr, time, filepath="results.csv"):
    row = {
        "Model": model_name,
        "AUC": max(history['val']['auc']),
        "epochs": num_epochs,
        "lr": lr,
        "time": time
    }

    if os.path.exists(filepath):
        df = pd.read_csv(filepath)
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
    else:
        df = pd.DataFrame([row])

    df.to_csv(filepath, index=False)

### Train loop

In [32]:
def training_and_validation(model_name, model, optimizer, criterion, num_epochs=NUM_EPOCHS, lr=LR):
    print(f'Model_name: {model_name}')
    
    metric_calculator = MetricCalculator()

    # initialize history (только loss и AUC)
    history = {
        'train': {'loss': [], 'auc': []},
        'val': {'loss': [], 'auc': []}
    }

    best_val_auc = 0.0
    best_epoch = 0
    best_val_auc_per_class = None
    total_train_time = 0.0
    
    for epoch in range(num_epochs):
         # ---------------------- TRAIN ----------------------
        start_time = t.time()
        model.train()
        running_loss = 0.0
        train_labels, train_preds = [], []

        for batch_idx, (imgs, labels) in enumerate(train_loader):
        
            imgs = imgs.to(DEVICE, dtype=torch.float)
            labels = labels.to(DEVICE, dtype=torch.float)
        
            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
        
            running_loss += loss.item() * imgs.size(0)
            probs = torch.sigmoid(logits).detach().cpu().numpy()
            train_preds.append(probs)
            train_labels.append(labels.cpu().numpy())

        epoch_train_time = t.time() - start_time
        total_train_time += epoch_train_time  
        
        train_preds = np.vstack(train_preds)
        train_labels = np.vstack(train_labels)
        train_epoch_loss = running_loss / len(train_dataset)
        train_auc_mean, train_auc_per_class = metric_calculator.compute_auc(train_labels, train_preds)

        history['train']['loss'].append(train_epoch_loss)
        history['train']['auc'].append(train_auc_mean)

        print(f"Epoch {epoch+1} Train loss: {train_epoch_loss:.2f}")
        metric_calculator.print_auc(train_auc_mean, train_auc_per_class, phase="Train")

        # ---------------------- VALIDATION ----------------------
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs = imgs.to(DEVICE, dtype=torch.float)
                labels = labels.to(DEVICE, dtype=torch.float)
        
                logits = model(imgs)
                loss = criterion(logits, labels)
                val_loss += loss.item() * imgs.size(0)
        
                probs = torch.sigmoid(logits).detach().cpu().numpy()
                val_preds.append(probs)
                val_labels.append(labels.cpu().numpy())

                del imgs, labels, logits, probs
        
        val_preds = np.vstack(val_preds)
        val_labels = np.vstack(val_labels)
        
        val_epoch_loss = val_loss / len(val_dataset)
        val_auc_mean, val_auc_per_class = metric_calculator.compute_auc(val_labels, val_preds)

        history['val']['loss'].append(val_epoch_loss)
        history['val']['auc'].append(val_auc_mean)

        print(f"Epoch {epoch+1} Validation loss: {val_epoch_loss:.2f}")
        metric_calculator.print_auc(val_auc_mean, val_auc_per_class, phase="Validation")


        # ---------------------- SAVE BEST MODEL ----------------------
        if epoch>6 and val_auc_mean > best_val_auc:
            best_val_auc = val_auc_mean
            best_val_auc_per_class = val_auc_per_class.copy()  # сохраняем для bar-графика
            best_epoch = epoch + 1
            save_path = f"best_{model_name}_epoch_{best_epoch}_auc{best_val_auc:.4f}.pth"
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best model at epoch {best_epoch} with AUC={best_val_auc:.4f}: {save_path}")

    total_time_minutes = total_train_time/60
    
    print(f"Best validation AUC={best_val_auc:.4f} at epoch {best_epoch}")
    print(f"{model_name} was trained for {total_time_minutes:.0f} minutes")


    # ---------------------- PLOTS & SAVE HISTORY ----------------------
    plot_learning_curves(model_name, history, num_epochs)
    bar_aucs(best_val_auc_per_class, model_name)

    save_results_csv(model_name, history, num_epochs, lr, total_time_minutes)

### Launching an experiment

In [33]:
MODEL_NAMES = ["simplecnn", "resnet18", "resnext50", "efficientnet_b0", "densenet121"]

In [34]:
for model_name in MODEL_NAMES:
    model = CNNWrapper(model_name)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.BCEWithLogitsLoss()
    training_and_validation(model_name, model, optimizer, criterion)

Model_name: simplecnn
Epoch 1 Train loss: 0.63

Train AUC: 0.4906
Epoch 1 Validation loss: 0.55

Validation AUC: 0.4504
Epoch 2 Train loss: 0.58

Train AUC: 0.4913
Epoch 2 Validation loss: 0.54

Validation AUC: 0.5472
Epoch 3 Train loss: 0.56

Train AUC: 0.5131
Epoch 3 Validation loss: 0.53

Validation AUC: 0.6486
Epoch 4 Train loss: 0.54

Train AUC: 0.5528
Epoch 4 Validation loss: 0.51

Validation AUC: 0.6977
Epoch 5 Train loss: 0.52

Train AUC: 0.5984
Epoch 5 Validation loss: 0.50

Validation AUC: 0.7285
Epoch 6 Train loss: 0.51

Train AUC: 0.6353
Epoch 6 Validation loss: 0.48

Validation AUC: 0.7449
Epoch 7 Train loss: 0.49

Train AUC: 0.6647
Epoch 7 Validation loss: 0.46

Validation AUC: 0.7608
Epoch 8 Train loss: 0.48

Train AUC: 0.6821
Epoch 8 Validation loss: 0.45

Validation AUC: 0.7684
Saved new best model at epoch 8 with AUC=0.7684: best_simplecnn_epoch_8_auc0.7684.pth
Epoch 9 Train loss: 0.48

Train AUC: 0.6971
Epoch 9 Validation loss: 0.45

Validation AUC: 0.7745
Saved new 

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 239MB/s]


Model_name: resnet18
Epoch 1 Train loss: 0.58

Train AUC: 0.6629
Epoch 1 Validation loss: 0.48

Validation AUC: 0.7785
Epoch 2 Train loss: 0.44

Train AUC: 0.7627
Epoch 2 Validation loss: 0.41

Validation AUC: 0.8047
Epoch 3 Train loss: 0.40

Train AUC: 0.7928
Epoch 3 Validation loss: 0.38

Validation AUC: 0.8138
Epoch 4 Train loss: 0.38

Train AUC: 0.8053
Epoch 4 Validation loss: 0.37

Validation AUC: 0.8228
Epoch 5 Train loss: 0.37

Train AUC: 0.8165
Epoch 5 Validation loss: 0.37

Validation AUC: 0.8271
Epoch 6 Train loss: 0.36

Train AUC: 0.8262
Epoch 6 Validation loss: 0.36

Validation AUC: 0.8355
Epoch 7 Train loss: 0.36

Train AUC: 0.8332
Epoch 7 Validation loss: 0.36

Validation AUC: 0.8404
Epoch 8 Train loss: 0.35

Train AUC: 0.8408
Epoch 8 Validation loss: 0.36

Validation AUC: 0.8438
Saved new best model at epoch 8 with AUC=0.8438: best_resnet18_epoch_8_auc0.8438.pth
Epoch 9 Train loss: 0.35

Train AUC: 0.8466
Epoch 9 Validation loss: 0.36

Validation AUC: 0.8450
Saved new be

Downloading: "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" to /root/.cache/torch/hub/checkpoints/resnext50_32x4d-7cdf4587.pth
100%|██████████| 95.8M/95.8M [00:00<00:00, 210MB/s]


Model_name: resnext50
Epoch 1 Train loss: 0.52

Train AUC: 0.7028
Epoch 1 Validation loss: 0.41

Validation AUC: 0.7932
Epoch 2 Train loss: 0.40

Train AUC: 0.7818
Epoch 2 Validation loss: 0.38

Validation AUC: 0.8134
Epoch 3 Train loss: 0.37

Train AUC: 0.8084
Epoch 3 Validation loss: 0.37

Validation AUC: 0.8212
Epoch 4 Train loss: 0.36

Train AUC: 0.8222
Epoch 4 Validation loss: 0.37

Validation AUC: 0.8282
Epoch 5 Train loss: 0.35

Train AUC: 0.8295
Epoch 5 Validation loss: 0.36

Validation AUC: 0.8337
Epoch 6 Train loss: 0.34

Train AUC: 0.8424
Epoch 6 Validation loss: 0.36

Validation AUC: 0.8349
Epoch 7 Train loss: 0.33

Train AUC: 0.8510
Epoch 7 Validation loss: 0.36

Validation AUC: 0.8388
Epoch 8 Train loss: 0.32

Train AUC: 0.8603
Epoch 8 Validation loss: 0.36

Validation AUC: 0.8400
Saved new best model at epoch 8 with AUC=0.8400: best_resnext50_epoch_8_auc0.8400.pth
Epoch 9 Train loss: 0.31

Train AUC: 0.8718
Epoch 9 Validation loss: 0.36

Validation AUC: 0.8429
Saved new 

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


Saved per-class AUC bar chart to: plots/resnext50_aucs_per_class.png


100%|██████████| 20.5M/20.5M [00:00<00:00, 188MB/s]


Model_name: efficientnet_b0
Epoch 1 Train loss: 0.64

Train AUC: 0.6160
Epoch 1 Validation loss: 0.58

Validation AUC: 0.7423
Epoch 2 Train loss: 0.52

Train AUC: 0.7371
Epoch 2 Validation loss: 0.49

Validation AUC: 0.7584
Epoch 3 Train loss: 0.47

Train AUC: 0.7567
Epoch 3 Validation loss: 0.46

Validation AUC: 0.7728
Epoch 4 Train loss: 0.45

Train AUC: 0.7684
Epoch 4 Validation loss: 0.44

Validation AUC: 0.7842
Epoch 5 Train loss: 0.43

Train AUC: 0.7843
Epoch 5 Validation loss: 0.42

Validation AUC: 0.7925
Epoch 6 Train loss: 0.41

Train AUC: 0.7906
Epoch 6 Validation loss: 0.39

Validation AUC: 0.8055
Epoch 7 Train loss: 0.39

Train AUC: 0.8012
Epoch 7 Validation loss: 0.38

Validation AUC: 0.8095
Epoch 8 Train loss: 0.38

Train AUC: 0.8070
Epoch 8 Validation loss: 0.38

Validation AUC: 0.8150
Saved new best model at epoch 8 with AUC=0.8150: best_efficientnet_b0_epoch_8_auc0.8150.pth
Epoch 9 Train loss: 0.38

Train AUC: 0.8146
Epoch 9 Validation loss: 0.38

Validation AUC: 0.817

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 230MB/s]


Model_name: densenet121
Epoch 1 Train loss: 0.62

Train AUC: 0.5954
Epoch 1 Validation loss: 0.53

Validation AUC: 0.7538
Epoch 2 Train loss: 0.48

Train AUC: 0.7434
Epoch 2 Validation loss: 0.44

Validation AUC: 0.7911
Epoch 3 Train loss: 0.42

Train AUC: 0.7789
Epoch 3 Validation loss: 0.40

Validation AUC: 0.8117
Epoch 4 Train loss: 0.39

Train AUC: 0.7932
Epoch 4 Validation loss: 0.38

Validation AUC: 0.8212
Epoch 5 Train loss: 0.38

Train AUC: 0.8089
Epoch 5 Validation loss: 0.37

Validation AUC: 0.8270
Epoch 6 Train loss: 0.37

Train AUC: 0.8154
Epoch 6 Validation loss: 0.37

Validation AUC: 0.8317
Epoch 7 Train loss: 0.36

Train AUC: 0.8252
Epoch 7 Validation loss: 0.37

Validation AUC: 0.8344
Epoch 8 Train loss: 0.36

Train AUC: 0.8339
Epoch 8 Validation loss: 0.36

Validation AUC: 0.8356
Saved new best model at epoch 8 with AUC=0.8356: best_densenet121_epoch_8_auc0.8356.pth
Epoch 9 Train loss: 0.35

Train AUC: 0.8407
Epoch 9 Validation loss: 0.36

Validation AUC: 0.8382
Saved 