In [1]:
pip install laplace-torch

Note: you may need to restart the kernel to use updated packages.


In [2]:
from prediction.models import ResNet
from prediction.disease_prediction import hp_default_value
import os
import torch
from laplace import Laplace
# from analysis.cross_ds_inference import load_model

In [3]:
def load_model(ckpt_dir):
    model_choose = hp_default_value['model']
    num_classes = hp_default_value['num_classes']
    lr = hp_default_value['lr']
    pretrained = True  # Replace with actual value or source
    model_scale = hp_default_value['model_scale']

    if model_choose == 'resnet':
        model_type = ResNet

    file_list = [f for f in os.listdir(ckpt_dir) if f.endswith('.ckpt')]
    assert len(file_list) == 1, f"Expected 1 checkpoint file, but found {len(file_list)}."
    ckpt_path = os.path.join(ckpt_dir, file_list[0])
    
    model = model_type.load_from_checkpoint(
        ckpt_path,
        num_classes=num_classes,
        lr=lr,
        pretrained=pretrained,
        model_scale=model_scale
    )

    return model


In [63]:
ckpt_dir = "prediction/run/chexpert-Pleural Effusion-fp50-npp1-rs0-image_size224/version_0/checkpoints"
assert os.path.exists(ckpt_dir), f"Checkpoint directory does not exist: {ckpt_dir}"


In [64]:
chexpert_model = load_model(ckpt_dir)
print("CheXpert model loaded successfully.")


Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint prediction/run/chexpert-Pleural Effusion-fp50-npp1-rs0-image_size224/version_0/checkpoints/epoch=8-step=2412.ckpt`


CheXpert model loaded successfully.


In [67]:
chexpert_model.eval


<bound method Module.eval of ResNet(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsampl

In [73]:
la = Laplace(chexpert_model, likelihood="classification", subset_of_weights="all", hessian_structure="diag")

In [75]:
from dataloader.dataloader import CheXpertDataResampleModule

In [82]:
# Define parameters for initialization
img_data_dir = "prediction/run/chexpert-Pleural Effusion-fp50-npp1-rs0-image_size224"
csv_file_img = "datafiles/chexpert.sample.allrace.csv"
image_size = 224
pseudo_rgb = True
batch_size = 32
num_workers = 4
augmentation = True
outdir = "prediction/run/chexpert-Pleural Effusion-fp50-npp1-rs0-image_size224"
version_no = "0"
female_perc_in_training = 50
chose_disease = "Pleural Effusion"
random_state = 42
num_classes = 1
num_per_patient = 1
prevalence_setting = 'separate'
isFlip = False

# Initialize the data module
data_module = CheXpertDataResampleModule(
    img_data_dir=img_data_dir,
    csv_file_img=csv_file_img,
    image_size=image_size,
    pseudo_rgb=pseudo_rgb,
    batch_size=batch_size,
    num_workers=num_workers,
    augmentation=augmentation,
    outdir=outdir,
    version_no=version_no,
    female_perc_in_training=female_perc_in_training,
    chose_disease=chose_disease,
    random_state=random_state,
    num_classes=num_classes,
    num_per_patient=num_per_patient,
    prevalence_setting=prevalence_setting,
    isFlip=isFlip
)

# Get the training dataloader
train_loader = data_module.train_dataloader()

  df = pd.read_csv(self.csv_file_img, header=0)


DEBUG                      Enlarged Cardiomediastinum  Cardiomegaly  Lung Opacity  \
patient_id   sex                                                              
patient00001 Female                         0.0           0.0           0.0   
patient00002 Female                         0.0           0.0           1.0   
patient00003 Male                           0.0           0.0           0.0   
patient00004 Female                         0.0           0.0           0.0   
patient00005 Male                           0.0           0.0           0.0   
...                                         ...           ...           ...   
patient64533 Male                           0.0           0.5           0.0   
patient64534 Male                           0.0           0.0           1.0   
patient64535 Male                           0.0           0.0           1.0   
patient64536 Female                         0.0           0.0           0.0   
patient64537 Male                           0.

  df = pd.read_csv(self.csv_file_img, header=0)


PATIENT WISE disease prevalence
Disease prevalence total: {'Enlarged Cardiomediastinum': 0.10921150971599403, 'Cardiomegaly': 0.17848467862481315, 'Lung Opacity': 0.56273355754858, 'Lung Lesion': 0.0763421773791729, 'Edema': 0.31133221225710017, 'Consolidation': 0.1292351768809168, 'Pneumonia': 0.06086509715994021, 'Atelectasis': 0.2784317389138017, 'Pneumothorax': 0.10784130543099153, 'Pleural Effusion': 0.4078070503238665, 'Pleural Other': 0.029942077727952168, 'Fracture': 0.07450485799701047}
Disease prevalence Female: {'Enlarged Cardiomediastinum': 0.09545136459062281, 'Cardiomegaly': 0.16994401679496152, 'Lung Opacity': 0.5590972708187544, 'Lung Lesion': 0.07655703289013296, 'Edema': 0.313925822253324, 'Consolidation': 0.12393282015395381, 'Pneumonia': 0.05867739678096571, 'Atelectasis': 0.26168649405178446, 'Pneumothorax': 0.10076976906927922, 'Pleural Effusion': 0.40528341497550735, 'Pleural Other': 0.0258222533240028, 'Fracture': 0.06452064380685794}
Disease prevalence Male: {'

Loading Data: 100%|██████████| 17098/17098 [00:03<00:00, 5349.50it/s]
Loading Data: 100%|██████████| 2849/2849 [00:00<00:00, 7999.82it/s]
Loading Data: 100%|██████████| 17100/17100 [00:03<00:00, 4832.76it/s]

#train:  17098
#val:    2849
#test:   17100





In [83]:
la.fit(train_loader)

KeyboardInterrupt: 