In [None]:

import torch
import torch.optim as optim
import torch
from torch.utils.data import DataLoader
import pandas as pd

import segmentation_models_pytorch as smp

import lightning as L
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import torchmetrics as tm

from config import Config
from dataset import *
from utils import DiceCELossSplitter, plot_img_label_pred, AggregateTestingResultsCallback
from transforms import ImageVisualizer
import wandb

import monai
from monai.transforms import *

from tabulate import tabulate
import matplotlib.pyplot as plt

# Fix the global RNG seeds; workers=True also asks Lightning to seed dataloader workers.
seed_everything(99, workers=True)


# Source domains are set globally: they are used both by SVDNA and for dataset preparation.
cfg = Config(source_domains=["Spectralis", "Topcon", "Cirrus"])
cfg.batch_size, cfg.epochs = 8, 100

In [None]:
# Build train/val/test splits. dataset_split=[0.8, 0.2] presumably divides the source-domain
# data into train/val, with use_official_testset=True supplying the test split — TODO confirm.
dataset_prep = OCTDatasetPrep(cfg.train_dir, source_domains=cfg.source_domains)
train_data, val_data, test_data = dataset_prep.get_datasets(
    dataset_split=[0.8, 0.2],
    use_official_testset=True,
)

train_dataset = MakeDataset(train_data, cfg.train_transforms)
val_dataset = MakeDataset(val_data, cfg.val_transforms)
test_dataset = MakeDataset(test_data, cfg.test_transforms)

# Settings shared by every loader; only shuffling differs (training data is shuffled).
_loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=7, persistent_workers=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)


In [None]:
@torch.no_grad()
def compute_mask_channel_stats(loader, skip_channels=1):
    """Return per-channel (mean, std) of the ``'masks'`` tensors yielded by ``loader``.

    Every sample of every batch contributes, and both statistics are taken over all
    pixels of all samples. The std is the population std over pixels.

    Args:
        loader: iterable of dict batches holding a ``'masks'`` tensor indexed as
            (batch, channel, H, W) — layout assumed from the original ``[0, 1:, :, :]``
            indexing; TODO confirm.
        skip_channels: number of leading channels to ignore (channel 0 by default).

    Returns:
        ``(means, stds)``: float32 tensors with one entry per non-skipped channel.

    Raises:
        ValueError: if ``loader`` yields no mask pixels.
    """
    pixel_count = 0
    channel_sum = None
    channel_sq_sum = None
    for sample in loader:
        # Accumulate in float64 so sums over the whole dataset do not lose precision.
        label = sample['masks'][:, skip_channels:].detach().cpu().double()
        batch_sum = label.sum(dim=(0, 2, 3))
        batch_sq_sum = (label * label).sum(dim=(0, 2, 3))
        if channel_sum is None:
            channel_sum, channel_sq_sum = batch_sum, batch_sq_sum
        else:
            channel_sum += batch_sum
            channel_sq_sum += batch_sq_sum
        # Pixels per channel contributed by this batch: batch * H * W.
        pixel_count += label.shape[0] * label.shape[2] * label.shape[3]

    if pixel_count == 0:
        raise ValueError("compute_mask_channel_stats: loader yielded no mask pixels")

    means = channel_sum / pixel_count
    # E[x^2] - E[x]^2 can dip marginally below 0 through rounding; clamp before sqrt.
    stds = (channel_sq_sum / pixel_count - means ** 2).clamp_min(0).sqrt()
    return means.float(), stds.float()


means, stds = compute_mask_channel_stats(train_loader)

# Means are fractions of positive pixels; print them as percentages.
print(means*100, stds)

In [None]:

# Go through the entire training dataset (always ignoring channel 0) and collect some statistics.
# The masks are binary (every pixel is 0 or 1), and far fewer pixels are 1 than 0.
# 1. How many images have more than 1% of pixels == 1 in any of mask channels 1, 2 or 3?
# 2. What is the average share of pixels == 1 in channels 1, 2 and 3?
# 3. What is the average share of pixels == 1 in the entire mask (all channels)?
# 4. What share of all positive (== 1) pixels in channels 1, 2 and 3 lies in channel 1?

def _extract_masks(batch):
    """Return the mask tensor from one loader batch (dict with 'masks' or (images, masks) pair)."""
    if isinstance(batch, dict):
        return batch['masks']
    _, masks = batch
    return masks


@torch.no_grad()
def calculate_statistics(loader):
    """Collect foreground statistics of the segmentation masks yielded by ``loader``.

    Channel 0 is left out of the "channels 1, 2, 3" figures and only enters the
    "all channels" average. Reported averages are fractions in [0, 1] (0.01 == 1 %).

    Args:
        loader: iterable of batches, either dicts holding a ``'masks'`` tensor or
            ``(images, masks)`` pairs. Masks are indexed as (batch, channel, H, W).
            Pixel values are assumed to be 0/1, so sums count positive pixels — TODO
            confirm for non-binary masks.

    Returns:
        dict mapping a human-readable description to its value:
        - the image count;
        - the number of images with more than 1 % positive pixels in any of channels 1+;
        - the per-image average positive fraction in channels 1+;
        - the per-image average positive fraction in all channels;
        - the share of positive pixels in channels 1+ that lie in channel 1
          (NaN when there are no positive pixels).

    Raises:
        ValueError: if ``loader`` yields no images.
    """
    total_images = 0
    images_with_more_than_1_percent = 0
    sum_image_fraction_in_channels_123 = 0.0
    sum_image_fraction_in_all_channels = 0.0
    total_positive_in_channel_1 = 0.0
    total_positive_in_channels_123 = 0.0

    for batch in loader:
        # Only the masks are needed; cast to float so .mean() also works on integer masks.
        masks = _extract_masks(batch).cpu().float()
        total_images += masks.shape[0]

        # Fraction of positive pixels per image and channel: (batch_size, num_channels)
        fractions = masks.mean(dim=[2, 3])
        foreground = fractions[:, 1:]

        images_with_more_than_1_percent += int((foreground > 0.01).any(dim=1).sum().item())

        # Accumulate per-image averages (summed over the batch) so dividing by
        # total_images gives a true per-image mean regardless of batch size.
        sum_image_fraction_in_channels_123 += foreground.mean(dim=1).sum().item()
        sum_image_fraction_in_all_channels += fractions.mean(dim=1).sum().item()

        total_positive_in_channel_1 += masks[:, 1].sum().item()
        total_positive_in_channels_123 += masks[:, 1:].sum().item()

    if total_images == 0:
        raise ValueError("calculate_statistics: the loader yielded no images")

    average_percentage_in_channels_123 = sum_image_fraction_in_channels_123 / total_images
    average_percentage_in_all_channels = sum_image_fraction_in_all_channels / total_images

    # Guard against masks with no positive pixels at all.
    ratio = (total_positive_in_channel_1 / total_positive_in_channels_123
             if total_positive_in_channels_123 > 0 else float("nan"))

    return {
        "Total images": total_images,
        "Images with more than 1% of positive pixels in channels 1, 2, or 3": images_with_more_than_1_percent,
        "Average percentage of positive pixels in channels 1, 2, or 3": average_percentage_in_channels_123,
        "Average percentage of positive pixels in all channels": average_percentage_in_all_channels,
        "Ratio of positive pixels in channel 1 to positive pixels in channels 1, 2, and 3": ratio,
    }

# Report every collected statistic as "<description>: <value>", one per line.
statistics = calculate_statistics(train_loader)
print("\n".join(f"{key}: {value}" for key, value in statistics.items()))