In [33]:
import torch
import pandas as pd
from transformers import AutoFeatureExtractor, SwinModel
from torchvision import transforms
from torch.utils.data import Dataset
import pathlib
import numpy as np
from PIL import Image
from paths import DICT_MIMIC_OBS_TO_INT, DICT_MIMIC_INT_TO_OBS, IMAGES_MIMIC_PATH, MIMIC_PATH_TEST, MIMIC_PATH_TRAIN, MIMIC_PATH_VAL, SWINB_IMAGENET22K_WEIGHTS, DICT_MIMIC_OBSKEY_TO_INT

class MIMICDataset(Dataset):
    """MIMIC-CXR samples described by a CSV split file.

    Every CSV row must provide a ``labels`` column of comma-separated label
    keys; each key is translated to an integer id through ``label_map``.

    NOTE: image loading is currently disabled. ``__getitem__`` returns an
    all-zero 3x384x384 placeholder so label statistics (e.g. class weights)
    can be computed without reading image files. While that is the case,
    ``transform``, ``processor``, ``partition`` and ``img_root_dir`` are
    stored but never used.
    """

    def __init__(self, transform, processor, partition, dataset_path, img_root_dir, label_map):
        # Preprocessing objects, kept for when image loading is re-enabled.
        self.transform = transform
        self.processor = processor
        self.partition = partition
        # Split metadata; __getitem__ currently reads only its ``labels`` column.
        self.dataset_df = pd.read_csv(dataset_path)
        self.img_root_dir = pathlib.Path(img_root_dir)
        # Maps each label key found in the CSV to an integer id.
        self.label_map = label_map

    def __len__(self):
        """Number of rows in the split CSV."""
        return self.dataset_df.shape[0]

    def __getitem__(self, idx):
        """Return ``{"image": placeholder, "labels": int64 tensor}`` for row ``idx``.

        Disabled image pipeline, kept for reference (``A`` would be
        albumentations, which is not among the visible imports):

            img_name = self.img_root_dir / row.images.split(",")[0]
            img = Image.open(img_name).convert("RGB")
            if isinstance(self.transform, transforms.Compose):
                img = self.transform(img)
            elif isinstance(self.transform, A.core.composition.Compose):
                img = self.transform(image=np.array(img))["image"]
            else:
                raise ValueError("Unknown transformation type.")
            img = self.processor(img, return_tensors="pt", size=384).pixel_values.squeeze()
        """
        row = self.dataset_df.iloc[idx]
        label_ids = [self.label_map[key] for key in row.labels.split(",")]
        # Placeholder image while only the labels are needed.
        placeholder = torch.zeros(3, 384, 384)
        return {"image": placeholder, "labels": torch.tensor(label_ids).long()}

    def get_labels_name(self, label_list):
        """Print the observation name of every position whose value equals 1.

        Only entries exactly equal to 1 are reported; other non-zero ids
        (e.g. 2) are not printed. Returns None.
        """
        for position, value in enumerate(label_list):
            if value == 1:
                print(DICT_MIMIC_INT_TO_OBS[position])
                             

In [34]:
# Shared preprocessing for the train/val splits: torchvision's Resize(int)
# matches the shorter image edge to 416 px, then CenterCrop takes 384x384.
val_train_transform = transforms.Compose([
    transforms.Resize(416),
    transforms.CenterCrop(384),
])

# Image processor for the checkpoint at SWINB_IMAGENET22K_WEIGHTS.
# NOTE(review): recent transformers releases deprecate AutoFeatureExtractor for
# vision models in favour of AutoImageProcessor — confirm for the pinned version.
processor = AutoFeatureExtractor.from_pretrained(SWINB_IMAGENET22K_WEIGHTS)

# NOTE: MIMICDataset.__getitem__ currently returns placeholder images, so the
# transform and processor passed here are not applied yet.
train_dataset = MIMICDataset(
    val_train_transform,
    processor,
    "train",
    MIMIC_PATH_TRAIN,
    IMAGES_MIMIC_PATH,
    DICT_MIMIC_OBSKEY_TO_INT,
)



In [35]:
# Spot-check the label encoding of the first ten training samples.
for sample_idx in range(10):
    sample_labels = train_dataset[sample_idx]["labels"]
    print("sample ", sample_idx, " :", sample_labels)

sample  0  : tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
sample  1  : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0])
sample  2  : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0])
sample  3  : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0])
sample  4  : tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0])
sample  5  : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
sample  6  : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
sample  7  : tensor([2, 0, 0, 1, 2, 0, 0, 0, 0, 1, 0, 0, 0, 0])
sample  8  : tensor([2, 0, 0, 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0])
sample  9  : tensor([0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0])


In [36]:
from tqdm import tqdm

def calculate_class_weights(dataset, num_classes=14, num_outcomes=3, low_weight=1e-6):
    """
    Calculate inverse-frequency weights for every (class, outcome) pair.

    For each class, the outcome ids found in ``sample["labels"]`` are counted.
    Every observed outcome gets the raw weight ``1 / count``, and each class row
    is then normalized to sum to 1, so rarer outcomes receive larger weights.

    Outcomes that never occur for a class get the raw weight ``low_weight``
    instead. Dividing by ``0 + 1e-6`` would give such an outcome a raw weight of
    1e6, i.e. almost all of the class's weight after normalization. Note that
    ``low_weight`` is compared against ``1 / count`` before normalization, so it
    is only "low" when it is small relative to the observed inverse counts.

    Args:
        dataset: Iterable of samples (e.g. a PyTorch Dataset). Each sample is a
            mapping whose ``"labels"`` entry holds one integer outcome id per
            class, i.e. has shape (num_classes,).
        num_classes: Number of classes (abnormalities).
        num_outcomes: Number of possible outcomes for each class.
        low_weight: Pre-normalization weight for outcomes with 0 occurrences;
            keep it > 0 so a class with no samples still gets finite weights.

    Returns:
        class_weights: A float tensor of shape (num_classes, num_outcomes) whose
        rows each sum to 1.

    Raises:
        ValueError: If a sample's labels do not have shape (num_classes,) or
            contain an outcome id outside ``[0, num_outcomes)``. Without this
            check, a negative id would silently wrap to the last column.
    """
    frequency = torch.zeros((num_classes, num_outcomes))
    class_idx = torch.arange(num_classes)

    for sample_idx, sample in enumerate(tqdm(dataset, desc="Calculating Frequencies", unit="sample")):
        labels = torch.as_tensor(sample["labels"]).long()
        if tuple(labels.shape) != (num_classes,):
            raise ValueError(
                f"Sample {sample_idx}: expected labels of shape ({num_classes},), got {tuple(labels.shape)}."
            )
        if num_classes and (labels.min() < 0 or labels.max() >= num_outcomes):
            raise ValueError(
                f"Sample {sample_idx}: outcome ids must lie in [0, {num_outcomes}), got {labels.tolist()}."
            )
        # Each class contributes exactly one (class, outcome) cell, so no index
        # pair repeats and the vectorized in-place add counts each sample once.
        frequency[class_idx, labels] += 1

    observed = frequency > 0
    # clamp(min=1) only avoids a 1/0 in the branch that torch.where discards;
    # observed counts are already >= 1.
    class_weights = torch.where(
        observed,
        1.0 / frequency.clamp(min=1),
        torch.full_like(frequency, low_weight),
    )
    class_weights = class_weights / class_weights.sum(dim=1, keepdim=True)  # Normalize weights per class

    # Print frequency distribution and weights for each class
    for i in range(num_classes):
        print(f"Class {DICT_MIMIC_INT_TO_OBS[i]}: Frequencies {frequency[i].tolist()} -> Weights {class_weights[i].tolist()}")

    return class_weights

In [37]:
# Per-(class, outcome) weights over the full training split.
class_weights = calculate_class_weights(dataset=train_dataset)

Calculating Frequencies: 100%|██████████| 152173/152173 [00:22<00:00, 6671.13sample/s]

Class enlarged cardiomediastinum: Frequencies [123861.0, 8461.0, 19851.0] -> Weights [0.045706793665885925, 0.6691040396690369, 0.28518912196159363]
Class cardiomegaly: Frequencies [109670.0, 35635.0, 6868.0] -> Weights [0.04988563433289528, 0.1535276472568512, 0.7965866923332214]
Class lung opacity: Frequencies [110976.0, 41051.0, 146.0] -> Weights [0.0013092210283502936, 0.003539307741448283, 0.9951515197753906]
Class lung lesion: Frequencies [145867.0, 5613.0, 693.0] -> Weights [0.004210993647575378, 0.10943257063627243, 0.8863564133644104]
Class edema: Frequencies [130480.0, 13905.0, 7788.0] -> Weights [0.03684917092323303, 0.345780611038208, 0.6173702478408813]
Class consolidation: Frequencies [143326.0, 5082.0, 3765.0] -> Weights [0.01486531924456358, 0.4192417860031128, 0.5658929347991943]
Class pneumonia: Frequencies [137846.0, 4949.0, 9378.0] -> Weights [0.0229609664529562, 0.6395387649536133, 0.3375002443790436]
Class atelectasis: Frequencies [109815.0, 35759.0, 6599.0] -> We




In [38]:
from tqdm import tqdm
import torch

def calculate_class_weights(dataset, num_classes=14, num_outcomes=3, low_weight=1e-6):
    """
    Calculate inverse-frequency weights for every (class, outcome) pair.

    For each class, the outcome ids found in ``sample["labels"]`` are counted.
    Every observed outcome gets the raw weight ``1 / count``, and outcomes that
    never occur get the raw weight ``low_weight``. Each class row is then
    normalized to sum to 1, so rarer outcomes receive larger weights.

    Because ``low_weight`` is compared against ``1 / count`` before
    normalization, it only acts as a "low" weight when it is small relative to
    the observed inverse counts. With very common outcomes (e.g. ~1e5
    occurrences, raw weight ~1e-5), the default 1e-6 is of the same order.

    Args:
        dataset: Iterable of samples (e.g. a PyTorch Dataset). Each sample is a
            mapping whose ``"labels"`` entry holds one integer outcome id per
            class, i.e. has shape (num_classes,).
        num_classes: Number of classes (abnormalities).
        num_outcomes: Number of possible outcomes for each class.
        low_weight: Pre-normalization weight for outcomes with 0 occurrences;
            keep it > 0 so a class with no samples still gets finite weights.

    Returns:
        class_weights: A float tensor of shape (num_classes, num_outcomes) whose
        rows each sum to 1.

    Raises:
        ValueError: If a sample's labels do not have shape (num_classes,) or
            contain an outcome id outside ``[0, num_outcomes)``. Without this
            check, a negative id would silently wrap to the last column.
    """
    frequency = torch.zeros((num_classes, num_outcomes))
    class_idx = torch.arange(num_classes)

    for sample_idx, sample in enumerate(tqdm(dataset, desc="Calculating Frequencies", unit="sample")):
        labels = torch.as_tensor(sample["labels"]).long()
        if tuple(labels.shape) != (num_classes,):
            raise ValueError(
                f"Sample {sample_idx}: expected labels of shape ({num_classes},), got {tuple(labels.shape)}."
            )
        if num_classes and (labels.min() < 0 or labels.max() >= num_outcomes):
            raise ValueError(
                f"Sample {sample_idx}: outcome ids must lie in [0, {num_outcomes}), got {labels.tolist()}."
            )
        # Each class contributes exactly one (class, outcome) cell, so no index
        # pair repeats and the vectorized in-place add counts each sample once.
        frequency[class_idx, labels] += 1

    # Raw weights: 1/count for observed outcomes, low_weight for unseen ones.
    # clamp(min=1) only avoids a 1/0 in the branch that torch.where discards.
    class_weights = torch.where(
        frequency > 0,
        1.0 / frequency.clamp(min=1),
        torch.full_like(frequency, low_weight),
    )
    class_weights = class_weights / class_weights.sum(dim=1, keepdim=True)  # Normalize weights per class

    # Print frequency distribution and weights for each class
    for i in range(num_classes):
        print(f"Class {DICT_MIMIC_INT_TO_OBS[i]}: Frequencies {frequency[i].tolist()} -> Weights {class_weights[i].tolist()}")

    return class_weights


In [39]:
# Recompute the weights with the calculate_class_weights defined just above.
class_weights = calculate_class_weights(dataset=train_dataset)

Calculating Frequencies:   0%|          | 0/152173 [00:00<?, ?sample/s]

Calculating Frequencies: 100%|██████████| 152173/152173 [00:22<00:00, 6622.37sample/s]

Class 0: Frequencies [123861.0, 8461.0, 19851.0] -> Weights [0.045706793665885925, 0.6691040396690369, 0.28518912196159363]
Class 1: Frequencies [109670.0, 35635.0, 6868.0] -> Weights [0.04988563433289528, 0.1535276472568512, 0.7965866923332214]
Class 2: Frequencies [110976.0, 41051.0, 146.0] -> Weights [0.0013092210283502936, 0.003539307741448283, 0.9951515197753906]
Class 3: Frequencies [145867.0, 5613.0, 693.0] -> Weights [0.004210993647575378, 0.10943257063627243, 0.8863564133644104]
Class 4: Frequencies [130480.0, 13905.0, 7788.0] -> Weights [0.03684917092323303, 0.345780611038208, 0.6173702478408813]
Class 5: Frequencies [143326.0, 5082.0, 3765.0] -> Weights [0.01486531924456358, 0.4192417860031128, 0.5658929347991943]
Class 6: Frequencies [137846.0, 4949.0, 9378.0] -> Weights [0.0229609664529562, 0.6395387649536133, 0.3375002443790436]
Class 7: Frequencies [109815.0, 35759.0, 6599.0] -> Weights [0.048280879855155945, 0.148269385099411, 0.8034497499465942]
Class 8: Frequencies [1




In [40]:
# Persist the weight tensor to class_weights.pt, resolved relative to the
# kernel's current working directory.
torch.save(obj=class_weights, f="class_weights.pt")