In [None]:
import pandas as pd
import os
from glob import glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim


# Regex alternation of the image folders to keep; it is matched against each
# image's file path below.
selected_folders = 'images_001|images_002|images_003|images_004|images_005'

# Load and filter data
data = pd.read_csv('./data/Data_Entry_2017.csv')
# Removing invalid ages. Copy so the column assignments below write to an
# independent frame rather than to a slice of the raw table.
data = data[data['Patient Age'] < 100].copy()
# File name -> path for every PNG found under ./data/images*/<subfolder>/.
data_image_paths = {os.path.basename(x): x for x in glob(os.path.join('.', 'data', 'images*', '*', '*.png'))}
data['path'] = data['Image Index'].map(data_image_paths.get)
data['Patient Age'] = data['Patient Age'].astype(int)
data['Finding Labels'] = data['Finding Labels'].map(lambda x: x.replace('No Finding', ''))
# Rows whose image is not on disk have a missing path. na=False counts them as
# "not selected" so the mask is strictly boolean and safe to index with.
mask = data['path'].str.contains(selected_folders, na=False)
data = data[mask].copy()


# Process labels
from itertools import chain
# One list of finding names per image ('' for images without findings).
findings_per_image = data['Finding Labels'].str.split('|')
all_labels = sorted({label for label in chain.from_iterable(findings_per_image) if label})
for label in all_labels:
    # Exact membership in the split list, so one label can never match as a
    # substring of another. `label=label` binds the current loop value.
    data[label] = findings_per_image.map(lambda findings, label=label: float(label in findings))

# Filter labels to keep
MIN_CASES = 1000
all_labels = [label for label in all_labels if data[label].sum() > MIN_CASES]
data['disease_vec'] = data[all_labels].values.tolist()

# Print information
print(f"Clean Labels ({len(all_labels)}): {[(label, int(data[label].sum())) for label in all_labels]}")
print('Scans found:', len(data_image_paths), ', Total Headers:', data.shape[0])

In [None]:
class ChestXrayDataset(Dataset):
    """Dataset that yields ``(image, label)`` pairs from the rows of a dataframe.

    Args:
        dataframe: Frame with a ``'path'`` column holding the image file path
            and one numeric column per entry of ``labels``.
        image_paths: Stored on the instance as ``image_paths``; not read by
            any method of this class.
        labels: Names of the dataframe columns that make up the label vector.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, dataframe, image_paths, labels, transform=None):
        self.dataframe = dataframe
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the image of row ``idx`` and build its label vector.

        Returns:
            ``(image, label)`` where ``image`` is the RGB image (after
            ``transform`` when one was given) and ``label`` is a 1-D float32
            tensor with one entry per name in ``labels``.
        """
        # Fetch the row once; every positional lookup builds a new Series.
        row = self.dataframe.iloc[idx]
        image = Image.open(row['path']).convert('RGB')
        # float32 matches the dtype of the model outputs, so losses and
        # concatenations downstream need no extra casting.
        label = torch.tensor(row[self.labels].astype(float).values, dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, label

# Define transforms
# Resize every image to 224x224, convert it to a tensor and standardize each
# channel with the mean / std values given below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Split data
from sklearn.model_selection import train_test_split
# Stratify on the first four characters of the findings string so each split
# keeps a similar mix of leading findings.
# NOTE(review): rows are split per image; if the CSV holds several scans of the
# same patient they can land in different splits -- confirm whether a
# patient-grouped split is needed.
train_df, test_df = train_test_split(data, test_size=0.20, random_state=2018, stratify=data['Finding Labels'].str[:4])
train_df, valid_df = train_test_split(train_df, test_size=0.10, random_state=2018, stratify=train_df['Finding Labels'].str[:4])

train_dataset = ChestXrayDataset(train_df, data_image_paths, all_labels, transform=transform)
valid_dataset = ChestXrayDataset(valid_df, data_image_paths, all_labels, transform=transform)
test_dataset = ChestXrayDataset(test_df, data_image_paths, all_labels, transform=transform)

# One batch size for every loader. The test loader used to yield single-image
# batches, which is slow for evaluation and gives consumers that take several
# images from the first batch (e.g. a 4-image preview) too few to index.
BATCH_SIZE = 64

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import torch
import torch.nn as nn
from torchvision.models import googlenet
from torchvision.models import vgg16

class GoogleNetModel(nn.Module):
    """Pretrained GoogLeNet whose final layer is resized to ``num_classes`` outputs.

    ``forward`` returns the sigmoid of the backbone output, i.e. one value in
    (0, 1) per class.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = googlenet(pretrained=True)
        # Swap the stock classification layer for one with `num_classes` units.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        self.model = backbone
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the backbone and squash its output with a sigmoid."""
        return self.sigmoid(self.model(x))
    
class VGGModel(nn.Module):
    """Pretrained VGG16 whose last classifier layer is resized to ``num_classes`` outputs.

    ``forward`` returns the sigmoid of the backbone output, i.e. one value in
    (0, 1) per class.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = vgg16(pretrained=True)
        # Replace only the last classifier layer; the rest keeps its weights.
        head_inputs = backbone.classifier[-1].in_features
        backbone.classifier[-1] = nn.Linear(head_inputs, num_classes)
        self.model = backbone
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the backbone and squash its output with a sigmoid."""
        return self.sigmoid(self.model(x))

class GoogleVGG(nn.Module):
    """Ensemble that fuses a pretrained GoogLeNet and a pretrained VGG16.

    Both backbones get a final layer with ``num_classes`` outputs. Their
    outputs are concatenated and passed through a small fully connected head,
    and ``forward`` returns the sigmoid of the head output: one value in
    (0, 1) per class.
    """

    def __init__(self, num_classes):
        super(GoogleVGG, self).__init__()
        self.googlenet = googlenet(pretrained=True)
        self.vgg16 = vgg16(pretrained=True)

        self.googlenet.fc = nn.Linear(self.googlenet.fc.in_features, num_classes)
        self.vgg16.classifier[-1] = nn.Linear(self.vgg16.classifier[-1].in_features, num_classes)

        # The head ends in a plain Linear layer: the single sigmoid lives in
        # forward(). Applying a sigmoid here as well would squash the output
        # twice and confine it to (0.5, 0.73), so it could never fall below a
        # 0.5 decision threshold.
        self.classifier = nn.Sequential(
            nn.Linear(num_classes * 2, num_classes),
            nn.ReLU(),
            nn.Linear(num_classes, num_classes),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-class probabilities for the image batch ``x``."""
        output_googlenet = self.googlenet(x)
        output_vgg16 = self.vgg16(x)

        # Concatenate along the feature axis: (batch, 2 * num_classes).
        combined_features = torch.cat((output_googlenet, output_vgg16), dim=1)

        return self.sigmoid(self.classifier(combined_features))


# Learning rate shared by all three optimizers.
LEARNING_RATE = 0.00001

# Binary cross-entropy on the sigmoid outputs; one instance serves every model.
criterion = nn.BCELoss()

vgg_model = VGGModel(len(all_labels))
vgg_optimizer = optim.Adam(vgg_model.parameters(), lr=LEARNING_RATE)
print(vgg_model)

google_model = GoogleNetModel(len(all_labels))
google_optimizer = optim.Adam(google_model.parameters(), lr=LEARNING_RATE)
print(google_model)

model = GoogleVGG(len(all_labels))
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
from tqdm import tqdm
# Training and Validation Loop
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; return ``(mean batch loss, exact-match accuracy in %)``.

    With an ``optimizer`` the pass trains the model (backward + step);
    without one it only evaluates. The caller switches ``model.train()`` /
    ``model.eval()`` and wraps evaluation in ``torch.no_grad()``.
    """
    is_training = optimizer is not None
    loss_sum = 0.0
    correct = 0
    seen = 0
    batches = 0
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)
        if is_training:
            optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs.float(), labels.float())
        if is_training:
            loss.backward()
            optimizer.step()
        loss_sum += loss.item()

        # A sample counts as correct only when every label matches.
        predicted = (outputs > 0.5).float()
        correct += (predicted == labels).all(dim=1).sum().item()
        seen += labels.size(0)
        batches += 1

    # An empty loader yields NaN rather than a ZeroDivisionError.
    mean_loss = loss_sum / batches if batches else float('nan')
    accuracy = correct / seen * 100 if seen else float('nan')
    return mean_loss, accuracy


def train_and_validate(model, train_loader, valid_loader, criterion, optimizer, num_epochs=20):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Batches are moved to whatever device the model's parameters live on, so
    the same loop works for a CPU or a GPU model.

    Args:
        model: Module mapping an image batch to per-class probabilities.
        train_loader: Iterable of ``(images, labels)`` training batches.
        valid_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``(train_loss, val_loss)``: two lists with the mean batch loss of each
        epoch. Per-epoch losses and accuracies are also printed.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    train_loss = []
    val_loss = []
    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        avg_train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_loss.append(avg_train_loss)

        # Validation Phase
        model.eval()
        with torch.no_grad():
            avg_valid_loss, val_accuracy = _run_epoch(model, valid_loader, criterion, device)
        val_loss.append(avg_valid_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {avg_train_loss:.4f}, '
              f'Validation Loss: {avg_valid_loss:.4f}')
        print(train_accuracy, val_accuracy)

    return train_loss, val_loss

# Continue with model training
# Use the GPU when one is available, and keep the per-epoch loss histories
# instead of discarding the return value.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loss_history, val_loss_history = train_and_validate(
    model.to(device), train_loader, valid_loader, criterion, optimizer, num_epochs=1
)

In [None]:
import gc
# Drop unreachable Python objects first so the tensors they hold are released...
gc.collect()
# ...then return PyTorch's cached, now-unused GPU memory blocks, which
# gc.collect() alone does not release.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# device = torch.device('cpu')

In [None]:
# from tqdm import tqdm 

# def train_and_validate(model, train_loader, valid_loader, criterion, optimizer, num_epochs=20):
#     model.to(device, dtype=torch.float32) 

#     for epoch in range(num_epochs):
#         model.train()
#         total_train_loss, total_valid_loss = 0, 0
#         total_train_correct, total_valid_correct = 0, 0
#         total_train_samples, total_valid_samples = 0, 0

#         # Training
#         for images, labels in tqdm(train_loader):
#             images = images.to(device, dtype=torch.float32)
#             labels = labels.to(device, dtype=torch.float32)
#             optimizer.zero_grad()
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             total_train_loss += loss.item()

#             predicted = outputs.round()
#             total_train_correct += (predicted == labels).all(dim=1).sum().item()
#             total_train_samples += labels.size(0)

#         # Validation
#         model.eval()
#         with torch.no_grad():
#             for images, labels in tqdm(valid_loader):
#                 images = images.to(device, dtype=torch.float32)
#                 labels = labels.to(device, dtype=torch.float32)
#                 outputs = model(images)
#                 loss = criterion(outputs, labels)
#                 total_valid_loss += loss.item()

#                 predicted = outputs.round()
#                 total_valid_correct += (predicted == labels).all(dim=1).sum().item()
#                 total_valid_samples += labels.size(0)

#         avg_train_loss = total_train_loss / len(train_loader)
#         avg_valid_loss = total_valid_loss / len(valid_loader)
#         train_accuracy = (total_train_correct / total_train_samples) * 100
#         valid_accuracy = (total_valid_correct / total_valid_samples) * 100

#         print(f'Epoch [{epoch+1}/{num_epochs}]: Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_valid_loss:.4f}', flush = True)
#         print(f'Train Accuracy: {train_accuracy:.2f}%, Validation Accuracy: {valid_accuracy:.2f}%', flush = True)

# # Call training/validation function
# # train_and_validate(model, train_loader, valid_loader, criterion, optimizer)
# train_and_validate(google_model, train_loader, valid_loader, criterion, google_optimizer)

In [None]:
# import matplotlib.pyplot as plt
# from sklearn.metrics import roc_curve, auc
# from tqdm import tqdm

# # Function to compute the ROC AUC score
# def compute_roc_auc(model, data_loader, num_classes):
#     model.eval()
#     y_true = torch.FloatTensor()
#     y_pred = torch.FloatTensor()
    
#     with torch.no_grad():
#         for images, labels in tqdm(data_loader):
#             outputs = model(images)
#             y_true = torch.cat((y_true, labels), 0)
#             y_pred = torch.cat((y_pred, outputs), 0)

#     roc_auc_dict = {}
#     for i, label in enumerate(all_labels):
#         fpr, tpr, _ = roc_curve(y_true[:, i], y_pred[:, i])
#         roc_auc_dict[label] = auc(fpr, tpr)
#         plt.plot(fpr, tpr, label=f'{label} (AUC = {roc_auc_dict[label]:.2f})')
    
#     plt.title('Receiver Operating Characteristic')
#     plt.legend(loc='lower right')
#     plt.plot([0, 1], [0, 1], 'r--')
#     plt.xlim([0, 1])
#     plt.ylim([0, 1])
#     plt.ylabel('True Positive Rate')
#     plt.xlabel('False Positive Rate')
#     plt.show()

#     return roc_auc_dict

# # Compute and plot ROC AUC
# roc_auc_scores = compute_roc_auc(model, test_loader, len(all_labels))


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# Function to compute the ROC AUC score
def compute_roc_auc(model, data_loader, num_classes):
    """Plot one ROC curve per label and return the per-label AUC values.

    Args:
        model: Module mapping an image batch to per-class scores.
        data_loader: Iterable of ``(images, labels)`` batches to evaluate on.
        num_classes: Accepted for backward compatibility; the label names
            (and therefore the evaluated columns) come from the module-level
            ``all_labels`` list.

    Returns:
        Dict mapping each label name to its AUC. A label whose ground truth
        holds a single class gets ``nan`` because its ROC curve is undefined.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    # Run on whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Collect per-batch results on the CPU and concatenate once at the end;
    # growing a tensor with torch.cat on every batch re-copies all earlier rows.
    true_batches, pred_batches = [], []
    with torch.no_grad():
        for images, labels in data_loader:
            outputs = model(images.to(device))
            true_batches.append(labels.float().cpu())
            pred_batches.append(outputs.float().cpu())

    if not true_batches:
        raise ValueError('data_loader yielded no batches')

    y_true = torch.cat(true_batches, 0).numpy()
    y_pred = torch.cat(pred_batches, 0).numpy()

    roc_auc_dict = {}
    for i, label in enumerate(all_labels):
        truth = y_true[:, i]
        if truth.min() == truth.max():
            # Only one class present: no ROC curve can be drawn for this label.
            roc_auc_dict[label] = float('nan')
            continue
        fpr, tpr, _ = roc_curve(truth, y_pred[:, i])
        roc_auc_dict[label] = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{label} (AUC = {roc_auc_dict[label]:.2f})')

    plt.title('Receiver Operating Characteristic')
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')  # chance-level diagonal
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

    return roc_auc_dict

# Evaluate the ensemble on the held-out test split: draws the ROC curves and
# keeps the per-label AUC values.
roc_auc_scores = compute_roc_auc(
    model=model,
    data_loader=test_loader,
    num_classes=len(all_labels),
)

In [None]:
# Show the label -> AUC mapping; flush so it appears immediately.
print(roc_auc_scores, flush=True)

In [None]:
# Visualization of Predictions
def visualize_predictions(model, data_loader, num_images=4):
    """Show up to ``num_images`` images titled with their predicted labels.

    Images are gathered across as many batches as needed, so the loader's
    batch size may be smaller than ``num_images``. A label is listed in a
    title when its model output exceeds 0.5; label names come from the
    module-level ``all_labels`` list.

    Args:
        model: Module mapping an image batch to per-class scores.
        data_loader: Iterable of ``(images, labels)`` batches.
        num_images: Maximum number of images to display.

    Raises:
        ValueError: If there is nothing to display (empty loader or
            ``num_images`` < 1).
    """
    model.eval()
    # Run on whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    image_batches, pred_batches = [], []
    collected = 0
    with torch.no_grad():  # inference only, no autograd graph needed
        for images, _ in data_loader:
            predictions = model(images.to(device)) > 0.5  # Threshold predictions
            image_batches.append(images.cpu())
            pred_batches.append(predictions.cpu())
            collected += images.size(0)
            if collected >= num_images:
                break

    count = max(0, min(num_images, collected))
    if count == 0:
        raise ValueError('no images to display')

    images = torch.cat(image_batches, 0)[:count]
    outputs = torch.cat(pred_batches, 0)[:count]

    # squeeze=False keeps `axs` two-dimensional even for a single image.
    fig, axs = plt.subplots(1, count, figsize=(15, 10), squeeze=False)
    axs = axs[0]
    for i in range(count):
        picture = images[i].permute(1, 2, 0)  # channels-first -> channels-last
        # imshow expects float RGB data in [0, 1]; rescale each image to that
        # range so values outside it are not clipped.
        low, high = picture.min(), picture.max()
        if high > low:
            picture = (picture - low) / (high - low)
        else:
            picture = torch.zeros_like(picture)
        axs[i].imshow(picture.numpy())
        axs[i].axis('off')
        disease_labels = ', '.join([all_labels[j] for j in range(outputs.shape[1]) if outputs[i, j]])
        axs[i].set_title(disease_labels)

    plt.show()

# Preview four test images together with the labels the ensemble predicts.
visualize_predictions(
    model=model,
    data_loader=test_loader,
    num_images=4,
)