In [1]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch.optim import AdamW
from torchvision import transforms
from transformers import AutoImageProcessor, SwinModel
import pandas as pd
import pathlib
from PIL import Image
import numpy as np
import albumentations as A
import cv2
import multiprocessing
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from paths import DICT_MIMICALL_OBS_TO_INT, IMAGES_MIMIC_PATH, SWINB_IMAGENET22K_WEIGHTS, DICT_MIMIC_OBSKEY_TO_INT, DICT_MIMICALL_INT_TO_OBS


  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [2]:
class SwinLightningModel(pl.LightningModule):
    """Swin backbone followed by a linear head emitting two logits per label.

    The head outputs ``num_classes * 2`` values per image; ``forward`` reshapes
    them to ``(batch, num_classes, 2)`` so every label is scored as its own
    two-way classification.

    Args:
        swin_weights: Identifier or path handed to ``SwinModel.from_pretrained``
            and ``AutoImageProcessor.from_pretrained``.
        num_classes: Number of labels predicted per image.
        lr: Learning rate, stored on the module (not used in this class body).
        weight_decay: Weight decay, stored on the module (not used in this class body).
        epochs: Epoch count, stored on the module (not used in this class body).
    """

    def __init__(self, swin_weights, num_classes=14, lr=1e-4, weight_decay=0.05, epochs=30):
        super().__init__()
        # Kept so that ``forward`` reshapes with the same label count the head was built for.
        self.num_classes = num_classes
        self.swin = SwinModel.from_pretrained(swin_weights)
        self.processor = AutoImageProcessor.from_pretrained(swin_weights)
        self.classifier = nn.Linear(self.swin.config.hidden_size, num_classes * 2)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.lr = lr
        self.weight_decay = weight_decay
        self.epochs = epochs
        # Filled by ``test_step``; the caller is responsible for resetting it between runs.
        self.test_step_outputs = []

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes, 2)`` for a batch of pixel values."""
        x = self.swin(x).pooler_output
        x = self.classifier(x)
        return x.view(-1, self.num_classes, 2)

    def test_step(self, batch, batch_idx):
        """Score one ``(x, y)`` batch, record the result and return it.

        The loss flattens logits to ``(-1, 2)`` and targets to ``(-1)``, so each
        label of each sample counts as one two-class example.
        # assumes y holds class indices in {0, 1} with shape (batch, num_classes) -- TODO confirm

        Returns:
            dict with keys ``"loss"``, ``"preds"`` (argmax over the last logit
            axis) and ``"y"``. The same dict is appended to ``self.test_step_outputs``.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat.view(-1, 2), y.view(-1))
        preds = torch.argmax(y_hat, dim=2)
        outputs = {"loss": loss, "preds": preds, "y": y}
        self.test_step_outputs.append(outputs)
        return outputs
    
# Build the module from the weights at SWINB_IMAGENET22K_WEIGHTS, then overwrite
# its parameters with the "state_dict" entry of the Lightning checkpoint.
model = SwinLightningModel(SWINB_IMAGENET22K_WEIGHTS)
model.eval()
checkpoint_path = "lightning_logs/version_3/checkpoints/swin_best.ckpt"
# map_location="cpu" lets a checkpoint written on a GPU be restored on a machine
# without one; the model is moved to its target device later.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True once the checkpoint contents allow it).
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["state_dict"])

  model.load_state_dict(torch.load(checkpoint_path)["state_dict"])


<All keys matched successfully>

In [3]:
# Export only the Swin backbone (model.swin) -- the linear classifier head is
# not part of it -- under "swin_mimic".
model.swin.save_pretrained(save_directory="swin_mimic")

In [20]:
class MIMICDataset(Dataset):
    """One split of the MIMIC CSV, served as ``(pixel_values, labels)`` pairs.

    Args:
        transform: A torchvision ``transforms.Compose`` or an albumentations
            ``Compose`` applied to the RGB image before the processor.
        processor: Callable image processor; called with ``return_tensors="pt"``
            and ``size=384`` and expected to expose ``.pixel_values``.
        partition: ``"train"``, ``"val"`` or ``"test"``.
        dataset_path: CSV file with a ``split`` column, an ``image_path`` column
            and one column per label name.
        img_root_dir: Directory that the CSV image paths are relative to.
        label_map: Mapping from a raw CSV label value to the integer class index.
        labels: Mapping whose keys are the label column names, in output order.

    Raises:
        ValueError: If ``partition`` is not one of the supported names.
    """

    # ``partition`` argument -> value stored in the CSV "split" column.
    _SPLIT_BY_PARTITION = {"train": "train", "val": "validate", "test": "test"}
    # Key looked up in ``label_map`` when a label cell is missing (NaN).
    _MISSING_LABEL_KEY = -2

    def __init__(self, transform, processor, partition, dataset_path, img_root_dir, label_map, labels):
        # Validate first so a typo fails before the (potentially large) CSV is parsed.
        if partition not in self._SPLIT_BY_PARTITION:
            raise ValueError("Unknown partition type.")

        self.transform = transform
        self.processor = processor
        self.partition = partition
        full_df = pd.read_csv(dataset_path)
        self.dataset_df = full_df[full_df["split"] == self._SPLIT_BY_PARTITION[partition]]

        self.img_root_dir = pathlib.Path(img_root_dir)
        self.label_map = label_map
        self.possible_labels = list(labels.keys())

    def __len__(self):
        """Number of rows in the selected split."""
        return len(self.dataset_df)

    def _encode_labels(self, row):
        """Turn one CSV row into a ``long`` tensor with one entry per label column."""
        # Sized from the configured label list rather than a fixed 14.
        encoded = torch.zeros(len(self.possible_labels))
        for position, column in enumerate(self.possible_labels):
            raw_value = row[column]
            if pd.isna(raw_value):
                raw_value = self._MISSING_LABEL_KEY
            encoded[position] = self.label_map[raw_value]
        return encoded.long()

    def __getitem__(self, idx):
        """Return ``(pixel_values, labels)`` for the ``idx``-th row of the split.

        Raises:
            ValueError: If ``self.transform`` is neither a torchvision nor an
                albumentations ``Compose``.
        """
        row = self.dataset_df.iloc[idx]
        # ``image_path`` may hold several comma-separated paths; only the first is used.
        img_path = self.img_root_dir / row.image_path.split(",")[0]
        # ``convert`` returns a fully loaded copy, so the file can be closed right away.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")

        if isinstance(self.transform, transforms.Compose):
            img = self.transform(img)
        elif isinstance(self.transform, A.core.composition.Compose):
            img = self.transform(image=np.array(img))["image"]
        else:
            raise ValueError("Unknown transformation type.")

        # Drop only the leading batch axis added by the processor.
        img = self.processor(img, return_tensors="pt", size=384).pixel_values.squeeze(0)
        return img, self._encode_labels(row)


In [21]:
# DataLoader worker processes: every CPU reported by the OS except one.
num_workers = multiprocessing.cpu_count() - 1
# Number of samples per evaluation batch.
BATCH_SIZE = 24

In [22]:
def val_test_transforms():
    """Build the preprocessing pipeline for validation and test images.

    Returns:
        A torchvision ``Compose`` applying ``Resize(416)`` followed by
        ``CenterCrop(384)``.
    """
    steps = [transforms.Resize(416), transforms.CenterCrop(384)]
    return transforms.Compose(steps)

In [23]:
# Test split of the MIMIC CSV, preprocessed with the model's own image processor.
test_dataset = MIMICDataset(
    dataset_path="mimic_all_with_image_paths.csv",
    img_root_dir=IMAGES_MIMIC_PATH,
    partition="test",
    transform=val_test_transforms(),
    processor=model.processor,
    labels=DICT_MIMICALL_OBS_TO_INT,
    label_map=DICT_MIMIC_OBSKEY_TO_INT,
)

In [24]:
# Unshuffled batches, so predictions stay in the row order of the test split.
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=num_workers,
)

In [25]:
from tqdm import tqdm

# Use the GPU when one is visible; otherwise fall back to the CPU instead of
# crashing on .to("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Reset the buffer so re-running this cell does not accumulate duplicate batches.
model.test_step_outputs = []
with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader)):
        batch = [b.to(device) for b in batch]
        model.test_step(batch, batch_idx)

# Concatenate the per-batch results along the sample axis and move them to host
# memory, where the metric cells below consume them.
preds = torch.cat([o["preds"] for o in model.test_step_outputs], dim=0).cpu()
y = torch.cat([o["y"] for o in model.test_step_outputs], dim=0).cpu()

  0%|          | 0/137 [00:00<?, ?it/s]

100%|██████████| 137/137 [00:21<00:00,  6.34it/s]


In [26]:
# Per-label accuracy (in percent) and the unweighted mean over all labels.
# The label count comes from the prediction tensor instead of a fixed 14.
num_labels = preds.shape[1]
accs = []
for i in range(num_labels):
    acc = ((preds[:, i] == y[:, i]).float().mean() * 100).item()
    accs.append(acc)
    print(f"Accuracy for {DICT_MIMICALL_INT_TO_OBS[i]}: {acc}")
print(f"Overall accuracy: {np.mean(accs)}")
    

Accuracy for Atelectasis: 73.14163208007812
Accuracy for Cardiomegaly: 73.41694641113281
Accuracy for Consolidation: 90.27226257324219
Accuracy for Edema: 77.4242935180664
Accuracy for Enlarged Cardiomediastinum: 89.14041137695312
Accuracy for Fracture: 97.15509796142578
Accuracy for Lung Lesion: 95.1361312866211
Accuracy for Lung Opacity: 69.53197479248047
Accuracy for No Finding: 84.613037109375
Accuracy for Pleural Effusion: 80.36096954345703
Accuracy for Pleural Other: 97.58336639404297
Accuracy for Pneumonia: 80.20801544189453
Accuracy for Pneumothorax: 96.11502075195312
Accuracy for Support Devices: 82.89997863769531
Overall accuracy: 84.78565270560128


In [27]:
# Accuracy is dominated by the majority class on rare findings, so also report
# the macro-averaged F1 score per label.
from sklearn.metrics import f1_score

# Convert once up front rather than copying a column to the host on every iteration.
y_true = y.cpu().numpy()
y_pred = preds.cpu().numpy()

f1_scores = []
for i in range(y_true.shape[1]):
    f1 = f1_score(y_true[:, i], y_pred[:, i], average="macro")
    f1_scores.append(f1)
    print(f"F1 score for {DICT_MIMICALL_INT_TO_OBS[i]}: {f1}")
print(f"Overall F1 score: {np.mean(f1_scores)}")

F1 score for Atelectasis: 0.5620230626776975
F1 score for Cardiomegaly: 0.631748395974429
F1 score for Consolidation: 0.4775540768735928
F1 score for Edema: 0.6801570020384917
F1 score for Enlarged Cardiomediastinum: 0.4712922529516416
F1 score for Fracture: 0.49278510473235065
F1 score for Lung Lesion: 0.5279722470700953
F1 score for Lung Opacity: 0.5868502637317123
F1 score for No Finding: 0.6899850885027247
F1 score for Pleural Effusion: 0.7873864648516171
F1 score for Pleural Other: 0.4938845022449296
F1 score for Pneumonia: 0.5638461593980993
F1 score for Pneumothorax: 0.670887169718287
F1 score for Support Devices: 0.8142908289982044
Overall F1 score: 0.6036187585545624


In [28]:
# Save the overall and per-label accuracy / F1 scores to a text file.
with open("accuracy_f1_scores_3.txt", "w") as f:
    f.write(f"Overall accuracy: {np.mean(accs)}\n")
    f.write(f"Overall F1 score: {np.mean(f1_scores)}\n")
    # Iterate over the metrics actually computed instead of a fixed 14.
    for i in range(len(accs)):
        f.write(f"Accuracy for {DICT_MIMICALL_INT_TO_OBS[i]}: {accs[i]}\n")
        f.write(f"F1 score for {DICT_MIMICALL_INT_TO_OBS[i]}: {f1_scores[i]}\n")

# Save the predictions with the ground-truth labels to a file, one sample per line.
# Both tensors are converted to NumPy once, not once per row.
preds_np = preds.cpu().numpy()
y_np = y.cpu().numpy()
with open("predictions_3.txt", "w") as f:
    for pred_row, true_row in zip(preds_np, y_np):
        f.write(f"Prediction: {pred_row} Ground truth: {true_row}\n")