In [None]:
import os
import numpy as np
import pandas as pd
import time
import csv
import random
import matplotlib.pyplot as plt
from PIL import Image
from barbar import Bar
import datetime
import time
from tqdm import tqdm 
import yaml


import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
from torchvision import transforms, datasets, models
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split


import sklearn.metrics as metrics
from sklearn.metrics import roc_auc_score

# Query CUDA availability once and derive the compute device from it:
# the first GPU when one is visible, otherwise the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")

# Echo both values so the cell output records where the run executed.
print(use_gpu)
print(device)

class Config(object):
    """Run configuration for the NIH ChestX-ray14 linear-evaluation notebook.

    Collects dataset locations, the pretrained-backbone checkpoint path,
    output directories and optimisation hyper-parameters as plain attributes.

    Side effect: constructing an instance creates ``save_path`` and
    ``model_path`` on disk (``os.makedirs(..., exist_ok=True)``).

    Args:
        data_per: Percentage of labelled training data to use. Must be one of
            the keys of ``TRAIN_CSV_BY_PERCENT``. Defaults to 1, the value the
            notebook previously hard-coded.

    Raises:
        ValueError: If ``data_per`` has no associated training CSV. Raised
            before any directory is created.
    """

    # Labelled-data percentage -> training CSV holding that subset.
    TRAIN_CSV_BY_PERCENT = {
        1: './one_train_label_data.csv',
        5: './5_train_label_data.csv',
        10: './10_train_label_data.csv',
        30: './30_train_label_data.csv',
        100: './train_label_data.csv',
    }

    def __init__(self, data_per=1):
        # Validate first: previously an unsupported value left `nih_train_df`
        # undefined and only failed later with an AttributeError.
        if data_per not in self.TRAIN_CSV_BY_PERCENT:
            supported = sorted(self.TRAIN_CSV_BY_PERCENT)
            raise ValueError(
                f"Unsupported data_per={data_per!r}; expected one of {supported}"
            )

        # --- dataset / run identity ---
        self.dataset_name = 'nih'
        self.data_dir = '/workspace/DATASETS/NIH_Chest-Xray-14'
        self.data_per = data_per
        self.mode = 'fr'  # 'fr' makes the notebook freeze the backbone

        # --- pretrained backbone checkpoint ---
        self.base_method = 'coboom'
        self.pre_v = 'v2'
        self.pre_e = 300
        self.eval_e = 230
        self.pre_b = 64
        self.weight_dir = f'/workspace/singh63/CoBoom/OLD/CoBoom/ckpt/{self.base_method}_{self.pre_v}/resnet18_NIH14_{self.pre_b}_{self.pre_e}'
        self.pre_method = f'resnet18_NIH14_{self.pre_b}_{self.pre_e}_{self.eval_e}'
        self.backbone_path = os.path.join(self.weight_dir, self.pre_method + '.pth')

        # --- output locations (model_path ends with '/', callers rely on it) ---
        self.save_path = f'./{self.dataset_name}_ckpt/' + self.base_method + '_' + self.pre_v
        self.model_path = f'{self.save_path}/{self.pre_method}/{self.mode}/{self.data_per}/'
        self.method_name = f'{self.dataset_name}_{self.pre_method}_{self.mode}_{self.data_per}'

        # --- data loading ---
        self.data_workers = 32
        self.shuffle_dataset = True
        self.random_seed = 24

        # --- optimisation ---
        self.lr = 0.003
        self.learning_rate_min = 0.000001

        self.batch_size = 128
        self.test_batch_size = 1
        self.num_classes = 15
        self.resize_size = 224
        self.epochs = 300

        # --- label CSVs ---
        self.nih_train_df = self.TRAIN_CSV_BY_PERCENT[data_per]
        self.nih_valid_df = '/workspace/DATASETS/NIH_Chest-Xray-14/test_label_data.csv'

        os.makedirs(self.save_path, exist_ok=True)
        os.makedirs(self.model_path, exist_ok=True)
# Global configuration instance read by the rest of the notebook.
# Constructing it also creates the checkpoint/log directories.
opt = Config()

# Last expression of the cell: shows the resolved backbone checkpoint path.
opt.backbone_path

In [None]:
class NIHdataset(torch.utils.data.Dataset):
    """Multi-label image dataset backed by a DataFrame of paths and label columns.

    Args:
        df: DataFrame with a ``"filename"`` column (paths opened with PIL) and
            one column per entry of ``class_names``.
        class_names: Names of the label columns. They are sorted before use,
            so each label vector follows the *sorted* order, not the order
            passed in.
        transform: Optional callable applied to the RGB PIL image.

    Raises:
        ValueError: If any requested label column is absent from ``df``.
            Previously a missing column silently reused the labels of the
            preceding class (or raised ``NameError`` for the first class).
    """

    def __init__(self, df,class_names,transform,):
        self.image_filepaths = df["filename"].values
        self.transform = transform
        self.pathologies = sorted(class_names)
        self.csv = df

        # Fail loudly on schema mismatches instead of training on wrong labels.
        missing = [name for name in self.pathologies if name not in self.csv.columns]
        if missing:
            raise ValueError(f"Label columns missing from DataFrame: {missing}")

        # One array per class, stacked then transposed so that each row is
        # the label vector of one sample.
        per_class = [self.csv[name].values for name in self.pathologies]
        self.labels = np.asarray(per_class).T
        self.labels = self.labels.astype(np.float32)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is loaded as RGB."""
        img = self.image_filepaths[idx]
        image = Image.open(img).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]

        return image,label

    def __len__(self):
        """Number of samples (rows of the source DataFrame)."""
        return len(self.image_filepaths)

def get_transform(image_size,phase):
    """Build the torchvision preprocessing pipeline for one phase.

    Args:
        image_size: Side length; images are resized to
            ``(image_size, image_size)``.
        phase: ``"train"`` (resize, random horizontal flip, random grayscale
            with p=0.2, tensor conversion, normalisation) or ``"val"``
            (resize, tensor conversion, normalisation).

    Returns:
        A ``transforms.Compose`` for the requested phase.

    Raises:
        ValueError: If ``phase`` is neither ``"train"`` nor ``"val"``.
            Previously an unknown phase silently produced an empty pipeline.
    """
    # NOTE(review): std values close to 1.0 are unusual for [0, 1] tensors.
    # The ImageNet statistics (mean=[0.485, 0.456, 0.406],
    # std=[0.229, 0.224, 0.225]) were the commented-out alternative here.
    # Confirm which statistics the pretrained backbone expects.
    normalize = transforms.Normalize(mean=[0.0904, 0.2219, 0.4431],
                                     std=[1.0070, 1.0294, 1.0249])
    resize = transforms.Resize((image_size, image_size))

    if phase == "train":
        t_list = [
            resize,
            transforms.RandomHorizontalFlip(),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            normalize,
        ]
    elif phase == "val":
        # No random augmentation at evaluation time: a random flip here made
        # the reported AUROC vary from run to run.
        t_list = [
            resize,
            transforms.ToTensor(),
            normalize,
        ]
    else:
        raise ValueError(f"Unknown phase {phase!r}; expected 'train' or 'val'")

    return transforms.Compose(t_list)

def getnihchex14_dataset():
    """Create the training and validation datasets described by ``opt``.

    Reads the CSVs at ``opt.nih_train_df`` and ``opt.nih_valid_df``, builds
    the phase-specific transforms at ``opt.resize_size`` and wraps each frame
    in an ``NIHdataset``.

    Returns:
        Tuple ``(train_dataset, valid_dataset)``.
    """
    # NOTE(review): the reduced-size training CSVs appear to have been
    # produced offline by sampling the full training CSV
    # (``train_df.sample(frac=...)`` followed by ``to_csv``); that dead code
    # was removed from this function.
    train_df = pd.read_csv(opt.nih_train_df)
    valid_df = pd.read_csv(opt.nih_valid_df)

    class_names = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
           'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass',
           'No Finding', 'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax']

    image_size = opt.resize_size
    train_transform = get_transform(image_size, phase='train')
    valid_transform = get_transform(image_size, phase='val')

    train_dataset = NIHdataset(train_df, class_names, transform=train_transform)
    valid_dataset = NIHdataset(valid_df, class_names, transform=valid_transform)

    return train_dataset,valid_dataset

def count_parameters(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Only tensors with ``requires_grad`` set are counted.
    """
    trainable = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable += tensor.numel()
    return trainable / 1000000

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also sets cuDNN to deterministic mode and turns its benchmark mode off.
    """
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Build both datasets once and report their sizes.
train_dataset,valid_dataset = getnihchex14_dataset()
print("Train data length:", len(train_dataset))
print("Valid data length:", len(valid_dataset))

# Worker count and shuffling come from the shared config instead of literals.
train_loader = DataLoader(
    train_dataset,
    batch_size=opt.batch_size,
    shuffle=opt.shuffle_dataset,
    num_workers=opt.data_workers,
    pin_memory=True,
)
# Evaluation does not need shuffling; a fixed order keeps runs comparable.
test_loader = DataLoader(
    valid_dataset,
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=opt.data_workers,
    pin_memory=True,
)

In [None]:
# Linear-probe model: ResNet-18 backbone initialised from the pretrained checkpoint, followed by a Linear + Sigmoid head
class LinearRegression(torch.nn.Module):
    """ResNet-18 feature extractor followed by a ``Linear`` + ``Sigmoid`` head.

    The backbone's ``fc`` layer is replaced by ``nn.Identity`` and its weights
    are loaded from the ``'online'`` entry of the checkpoint at
    ``opt.backbone_path``.

    Args:
        output_dim: Number of outputs of the linear head.
    """

    def __init__(self, output_dim):
        super().__init__()

        # Backbone: remember the feature width before dropping the classifier.
        self.model = models.resnet18(weights=None)
        self.n_inputs = self.model.fc.in_features
        self.model.fc = nn.Identity()

        # The checkpoint tensors are matched to the backbone's state-dict keys
        # purely by position: the first len(keys) checkpoint values are used.
        # NOTE(review): this assumes the checkpoint stores the encoder tensors
        # first and in the same order as torchvision's ResNet-18 - confirm.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        pretrained = torch.load(opt.backbone_path, map_location=device)['online']
        backbone_keys = list(self.model.state_dict().keys())
        pretrained_tensors = list(pretrained.values())[:len(backbone_keys)]
        self.model.load_state_dict(dict(zip(backbone_keys, pretrained_tensors)))

        # Head: probabilities per class.
        self.linear = nn.Sequential(
            nn.Linear(self.n_inputs, output_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return per-class sigmoid outputs for a batch of images."""
        features = self.model(x)
        flat = features.view(features.size(0), -1)
        return self.linear(flat)

# Seed all RNGs before constructing the model so the randomly initialised
# linear head is reproducible, then build the model with one output per class.
set_seed(opt.random_seed)
logreg = LinearRegression(opt.num_classes)

In [None]:
# Frozen-backbone ('fr') mode: report the size of the trainable head, then
# stop gradients for every backbone parameter.
if opt.mode == 'fr':
    num_params = count_parameters(logreg.linear)  # in millions
    # round() rather than "%d" on the raw product: the float round trip
    # through millions can land just below the integer and be truncated.
    print("Total Parameter: \t%d" % round(num_params*1000000))
    for param in logreg.model.parameters():
        param.requires_grad = False

In [None]:
def computeAUROC(dataGT, dataPRED):
    """Compute the per-class area under the ROC curve.

    Args:
        dataGT: Tensor of ground-truth labels, indexed ``[sample, class]``.
        dataPRED: Tensor of predicted scores, same layout as ``dataGT``.

    Returns:
        List of AUROC values for classes ``0 .. opt.num_classes - 1``.
        Classes for which ``roc_auc_score`` raises ``ValueError`` are
        omitted, so the list can be shorter than ``opt.num_classes`` and
        positions then no longer match class indices. A warning naming each
        skipped class index is emitted so this is not silent.
    """
    import warnings  # local import keeps this cell self-contained

    outAUROC = []
    datanpGT = dataGT.cpu().numpy()
    datanpPRED = dataPRED.cpu().numpy()

    for i in range(opt.num_classes):
        try:
            outAUROC.append(roc_auc_score(datanpGT[:, i], datanpPRED[:, i]))
        except ValueError as err:
            warnings.warn(f"AUROC undefined for class index {i}; skipped ({err})")
    return outAUROC

In [None]:
# Define Train function properly separated from optimizer initialization
def Train(regressor, dataloaderDownTrain, optimizer, scheduler, criterion):
    """Run one training epoch and return the mean batch loss.

    Args:
        regressor: Model exposing ``.model`` (backbone) and ``.linear`` (head).
        dataloaderDownTrain: Iterable of ``(input, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per call, after the last batch.
        criterion: Loss applied to ``(head_output, target)``.

    Returns:
        Sum of the batch losses divided by ``len(dataloaderDownTrain)``.
    """
    regressor.train()
    # NOTE(review): train() also puts the backbone in training mode even when
    # its weights are frozen; if it contains BatchNorm layers their running
    # statistics keep changing during this loop - confirm this is intended.

    # Only build a graph through the backbone when at least one of its
    # parameters is trainable. Previously the backbone always ran under
    # no_grad, so a non-frozen mode silently trained the head only.
    backbone_trainable = any(p.requires_grad for p in regressor.model.parameters())
    losstrain = 0

    for varInput, target in Bar(dataloaderDownTrain):
        varTarget = target.to(device)
        varInput = varInput.to(device)

        with torch.set_grad_enabled(backbone_trainable):
            features = regressor.model(varInput)

        varOutput = regressor.linear(features)
        lossvalue = criterion(varOutput, varTarget)

        optimizer.zero_grad()
        lossvalue.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(regressor.parameters(), max_norm=1.0)
        optimizer.step()

        losstrain += lossvalue.item()

    # Epoch-level schedule: one scheduler step per call.
    scheduler.step()

    # Printed after the scheduler step, so this is the rate for the next epoch.
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Current LR: {current_lr:.6f}")

    return losstrain / len(dataloaderDownTrain)

# Modified Test function with improved evaluation
def Test(regressor, dataLoaderTest):
    """Evaluate ``regressor`` on a loader and return per-class and mean AUROC.

    Args:
        regressor: Model exposing ``.model`` (backbone) and ``.linear`` (head).
        dataLoaderTest: Iterable of ``(input, target)`` batches.

    Returns:
        Tuple ``(aurocIndividual, aurocMean)``: the list returned by
        ``computeAUROC`` and the mean of that list.

    Raises:
        ValueError: If the loader yields no batches.
    """
    # The global cuDNN benchmark flag is deliberately left alone: forcing it
    # to True here overrode the deterministic setup done by set_seed for
    # every later epoch.
    regressor.eval()

    # Collect per-batch tensors and concatenate once, instead of re-copying
    # the growing result on every iteration.
    gt_batches = []
    pred_batches = []

    with torch.no_grad():
        for varInput, target in Bar(dataLoaderTest):
            gt_batches.append(target.to(device))
            features = regressor.model(varInput.to(device))
            pred_batches.append(regressor.linear(features))

    if not gt_batches:
        raise ValueError("dataLoaderTest yielded no batches; cannot compute AUROC")

    outGT = torch.cat(gt_batches, 0)
    outPRED = torch.cat(pred_batches, 0)

    aurocIndividual = computeAUROC(outGT, outPRED)
    aurocMean = np.array(aurocIndividual).mean()
    return aurocIndividual, aurocMean


In [None]:
# Move the model to the selected device before the optimizer captures its parameters.
logreg = logreg.to(device)

# Loss, optimizer and scheduler are created once here and passed into Train.
# BCELoss expects probabilities; the model head ends in a Sigmoid.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.AdamW(
    logreg.parameters(),
    lr=opt.lr,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=1e-10  # near-zero weight decay (effectively disabled)
)

# Cosine annealing with warm restarts; Train steps it once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,  # Restart every 10 epochs
    T_mult=2,  # Double the restart interval after each restart
    eta_min=opt.learning_rate_min
)

# NOTE(review): this LayerNorm is created but not applied by the Train/Test
# functions visible in this notebook - confirm whether it is still needed.
feature_normalizer = torch.nn.LayerNorm(logreg.n_inputs).to(device)


In [None]:
# Early-stopping state: best score so far and the number of consecutive
# epochs that failed to beat it by more than `plateau_threshold`.
best_auc = 0.0
count = 0
patience = 10
plateau_threshold = 0.001

# Per-epoch history used by the plotting cell.
train_losses = []
valid_aucs = []

for epoch in range(opt.epochs):
    losst = Train(logreg, train_loader, optimizer, scheduler, criterion)
    aurocIndividual, aurocMean = Test(logreg, test_loader)

    train_losses.append(losst)
    valid_aucs.append(aurocMean)

    print("Epoch: {},".format(epoch), "Train_loss: {:.3f},".format(losst), "Valid auc: {:.3f}".format(aurocMean))

    # Append one "epoch,auc,loss" record to the run log.
    with open(f'{opt.model_path}{opt.method_name}_logs.txt', 'a') as file:
        file.write(','.join(map(str, (epoch, aurocMean, losst))) + '\n')

    # No meaningful gain: count the stall and stop once patience runs out.
    if not (aurocMean > best_auc + plateau_threshold):
        count += 1
        if count >= patience:
            print(f"No improvement for {patience} epochs. Early stopping.")
            break
        continue

    # New best: checkpoint the model and reset the stall counter.
    torch.save(logreg.state_dict(), os.path.join(opt.model_path, f'{opt.method_name}.pth'))
    print('auc increased ({:.3f} --> {:.3f}). Saving model ...'.format(best_auc, aurocMean))
    best_auc = aurocMean
    count = 0

In [None]:
# Draw the per-epoch training loss and validation AUC side by side, then
# save the figure next to the checkpoints.
curve_panels = (
    (1, train_losses, 'Training Loss', 'Loss'),
    (2, valid_aucs, 'Validation AUC', 'AUC'),
)

plt.figure(figsize=(12, 5))
for panel_index, series, panel_title, y_label in curve_panels:
    plt.subplot(1, 2, panel_index)
    plt.plot(series)
    plt.title(panel_title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
plt.savefig(f'{opt.model_path}{opt.method_name}_training_curves.png')

In [None]:
# Reload the best checkpoint written by the training loop and evaluate it.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
log_checkpoint = torch.load(os.path.join(opt.model_path, f'{opt.method_name}.pth'), map_location=device)
logreg.load_state_dict(log_checkpoint)
logreg = logreg.to(device)

aurocIndividual, aurocMean = Test(logreg, test_loader)
print(f'Best validation AUC: {aurocMean:.2%}')
print()

formatted_aurocIndividual = [f'{auc:.2%}' for auc in aurocIndividual]

class_names = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 
           'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass', 
           'No Finding', 'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax']

# computeAUROC omits classes whose AUROC is undefined, which shifts every
# later value onto the wrong name. Names are paired by position, so flag a
# length mismatch instead of reporting misattributed scores silently.
alignment_warning = ''
if len(aurocIndividual) != len(class_names):
    alignment_warning = (
        f'WARNING: {len(aurocIndividual)} AUROC values for {len(class_names)} classes; '
        'per-class names below may be misaligned'
    )
    print(alignment_warning)

for name, formatted in zip(class_names, formatted_aurocIndividual):
    print(name, ' ', formatted)

with open(f'{opt.model_path}{opt.method_name}_logs.txt', 'a') as file:
    file.write('\n\n'+'Valid Mean AUC '+f'{aurocMean:.2%}'+'\n\n')    
    if alignment_warning:
        file.write(alignment_warning + '\n')
    for name, formatted in zip(class_names, formatted_aurocIndividual):
        file.write(f'{name} {formatted}\n')