In [None]:
import sys

import PIL
import pandas as pd
import torch
from torch.utils.data import DataLoader

import robustdg_modified.dataset as dataset
import robustdg_modified.config as cfg

torch.__version__

In [None]:
print(sys.version)
print(f"Num GPUs Available: {torch.cuda.device_count()}")

# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_device = torch.device(device)
torch_device

## Reproducibility

In [None]:
SEED = 1

# Dedicated RNG handed to the DataLoader(s) below; seed_everything ties it to SEED
# (presumably it also seeds the global python/numpy/torch RNGs — see cfg.reproducibility)
data_loader_generator = torch.Generator()
cfg.reproducibility.seed_everything(SEED, data_loader_generator)

## Dataset

In [None]:
# Selects which training image directory / CSVs (cfg.paths.*[TRAIN_KEY]) the notebook uses
TRAIN_KEYS = ("train", "augmented_train")
TRAIN_KEY = "augmented_train"
assert TRAIN_KEY in TRAIN_KEYS, f"TRAIN_KEY must be one of {TRAIN_KEYS}, got {TRAIN_KEY!r}"

### Train / Validation split

In [None]:
labels_csv = pd.read_csv(cfg.paths.LABELS_CSV[TRAIN_KEY])
domain_csv = pd.read_csv(cfg.paths.DOMAIN_TRAIN_CSV[TRAIN_KEY])

# The split below indexes both frames with labels_csv.index, so they must be row-aligned
assert len(labels_csv) == len(domain_csv), (
    f"labels ({len(labels_csv)} rows) and domain ({len(domain_csv)} rows) CSVs differ in length"
)

In [None]:
# Fraction passed to the splitter — presumably the share of rows assigned to training,
# with the remainder going to validation (confirm in dataset.get_split_train_validation_index)
TRAIN_FRACTION = 0.80

train_index, validation_index = dataset.get_split_train_validation_index(
    labels_csv.index, TRAIN_FRACTION
)

In [None]:
# Restrict both CSVs to the training rows, then derive names and one-hot frames from them
train_labels_csv, train_domain_csv = dataset.get_only_desired_indexes(
    train_index, labels_csv, domain_csv
)

train_img_names, train_img_labels, train_img_domain = (
    dataset.read.get_image_names(train_labels_csv),
    dataset.read.get_one_hot_labels(train_labels_csv),
    dataset.read.get_one_hot_domain(train_domain_csv),
)

In [None]:
# Same extraction as for training, applied to the held-out validation rows
val_labels_csv, val_domain_csv = dataset.get_only_desired_indexes(
    validation_index, labels_csv, domain_csv
)

val_img_names, val_img_labels, val_img_domain = (
    dataset.read.get_image_names(val_labels_csv),
    dataset.read.get_one_hot_labels(val_labels_csv),
    dataset.read.get_one_hot_domain(val_domain_csv),
)

### Test

In [None]:
test_labels_csv = pd.read_csv(cfg.paths.LABELS_CSV["test"])

# No domain CSV is read for the test split, so only image names and one-hot labels are derived
test_img_names, test_img_labels = (
    dataset.read.get_image_names(test_labels_csv),
    dataset.read.get_one_hot_labels(test_labels_csv),
)

### Dataset information

In [None]:
# Names of the one-hot encoded classes and domains, taken from the training frames
CLASSES, DOMAINS = (
    dataset.utils.metadata.get_one_hot_encoded_names(train_img_labels),
    dataset.utils.metadata.get_one_hot_encoded_names(train_img_domain),
)

CLASSES, DOMAINS

In [None]:
# Read dimensions from the same directory the train/validation datasets load from
# (IMG_DIR[TRAIN_KEY]) so args.img_c / img_h / img_w describe the images actually used
IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH = dataset.utils.metadata.get_image_dimensions(
    cfg.paths.IMG_DIR[TRAIN_KEY]
)
IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH

### Domain information

In [None]:
# Decode the one-hot training labels and domains back into name columns
diagnosis = dataset.utils.one_hot_encoding.convert_one_hot_df_to_names(
    train_img_labels, "diagnosis"
)
diagnosis_method = dataset.utils.one_hot_encoding.convert_one_hot_df_to_names(
    train_img_domain, "diagnosis method"
)
both = pd.concat([diagnosis, diagnosis_method], axis=1)

# Count of training images per diagnosis within each diagnosis method (the domain variable)
diagnosis_count_per_method = (
    both
    .groupby("diagnosis method")["diagnosis"]
    .value_counts()
)
diagnosis_count_per_method

## RobustDG Parameters

In [None]:
# Argument object presumably standing in for RobustDG's CLI args, filled from the metadata above
args = cfg.args_mock.ArgsMock(
    out_classes=CLASSES.size,  # number of target classes
    img_c=IMG_CHANNELS,
    img_h=IMG_HEIGHT,
    img_w=IMG_WIDTH,
    batch_size=cfg.hparams.BATCH_SIZE,
)

## RobustDG Datasets

In [None]:
# Train and validation share the image directory and builder; only the row subsets differ
train, validation = (
    dataset.create_robustdg_train_dataset(
        args=args,
        img_dir=cfg.paths.IMG_DIR[TRAIN_KEY],
        int_to_img_names=names,
        labels_df=labels,
        domain_df=domains,
        transform=None,
    )
    for names, labels, domains in (
        (train_img_names, train_img_labels, train_img_domain),
        (val_img_names, val_img_labels, val_img_domain),
    )
)

# The test set has no domain frame, so it uses the dedicated test builder
test = dataset.create_robustdg_test_dataset(
    args=args,
    img_dir=cfg.paths.IMG_DIR["test"],
    int_to_img_names=test_img_names,
    labels_df=test_img_labels,
    transform=None,
)

## Dataset Samples

In [None]:
# glob() order is filesystem-dependent; sort so the sample file is reproducible across machines
train_filenames = iter(sorted(cfg.paths.IMG_DIR[TRAIN_KEY].glob("*.jpg")))

filename = next(train_filenames, None)
if filename is None:
    raise FileNotFoundError(f"No .jpg images found in {cfg.paths.IMG_DIR[TRAIN_KEY]}")
filename

In [None]:
# Visual sanity check of a few training samples, given the class and domain names
# (2, 2 presumably sets the grid shape — TODO confirm against plot_some_train_samples)
dataset.utils.plot_samples.plot_some_train_samples(
    2, 2, train, CLASSES, DOMAINS
)

## Dataloaders

In [None]:
def make_dataloader(data, shuffle: bool, generator: torch.Generator) -> DataLoader:
    """Wrap ``data`` in a DataLoader with the shared batch size and seeded worker init.

    Each loader should get its own generator: DataLoader draws a base seed from its
    generator every time an iterator is created, so a shared generator would make the
    training shuffle order depend on how often the validation/test loaders were iterated.
    """
    return DataLoader(
        data,
        batch_size=args.batch_size,
        shuffle=shuffle,
        worker_init_fn=cfg.reproducibility.seed_worker,
        generator=generator,
    )


# Training keeps the generator seeded by seed_everything; eval loaders get independent ones
train_dataloader = make_dataloader(train, shuffle=True, generator=data_loader_generator)
val_dataloader = make_dataloader(
    validation, shuffle=False, generator=torch.Generator().manual_seed(SEED)
)
test_dataloader = make_dataloader(
    test, shuffle=False, generator=torch.Generator().manual_seed(SEED)
)

data_loaders = {
    "train": train_dataloader,
    "validation": val_dataloader,
    "test": test_dataloader,
}