In [1]:
# Load IPython's autoreload extension; mode 2 re-imports modules before each cell
# runs, so edits to the project's own .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import medmnist
from medmnist import INFO, Evaluator
import torch.utils.data as data
import matplotlib.pyplot as plt
import torch.nn as nn
import torch as t
from DiagnosisAI.models.resnet3d import generate_model
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from DiagnosisAI.utils.metrics import calc_metrics, calculate_type_errors


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Experiment configuration. Every name below is kept unchanged because other cells
# read these module-level variables.
data_flag = 'organmnist3d'  # key into medmnist.INFO selecting the dataset
download = True             # forwarded to the dataset constructors as download=...
NUM_EPOCHS = 3              # not used in the cells visible here; presumably read by training cells
BATCH_SIZE = 3              # not used in the cells visible here; presumably read by training cells
lr = 0.001                  # presumably the optimizer learning rate — confirm against the training cell

# Metadata entry for the selected dataset, and the dataset class that entry names.
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

# Derive the class count from the dataset's own label table instead of hard-coding it,
# so it stays correct if data_flag is changed (11 for 'organmnist3d').
n_classes = len(info['label'])

In [4]:
# Build the three dataset splits in train -> val -> test order, each with the same
# download setting.
train_dataset, val_dataset, test_dataset = (
    DataClass(split=split_name, download=download)
    for split_name in ('train', 'val', 'test')
)

Using downloaded and verified file: /home/michalheit/.medmnist/organmnist3d.npz
Using downloaded and verified file: /home/michalheit/.medmnist/organmnist3d.npz
Using downloaded and verified file: /home/michalheit/.medmnist/organmnist3d.npz


In [5]:
# Flatten the label array of every split, then join the three into one array
# covering the whole dataset.
train_labels, val_labels, test_labels = (
    split.labels.flatten()
    for split in (train_dataset, val_dataset, test_dataset)
)
global_labels = np.concatenate((train_labels, val_labels, test_labels))

# Mapping of split name -> label array; the next cell turns it into a DataFrame.
df = dict(zip(
    ("global", "train", "val", "test"),
    (global_labels, train_labels, val_labels, test_labels),
))

In [6]:
# Put the four label arrays side by side as columns. They have different lengths
# ("global" is the concatenation of the other three), so each is wrapped in a Series
# and the shorter ones are padded with missing values. Casting to the nullable
# 'Int64' dtype keeps the labels integral despite that padding.
df = pd.DataFrame(
    {name: pd.Series(values) for name, values in df.items()}
).astype('Int64')

# Per-split class counts: one column per split, rows aligned on the class index.
# The column names are set explicitly through `keys=` rather than inherited from the
# name value_counts() gives its result: pandas >= 2.0 names that result 'count', which
# would yield four identically named columns and break lookups such as temp['global'].
temp = pd.concat(
    [df[name].value_counts() for name in df.columns],
    axis=1,
    keys=list(df.columns),
).sort_index()

# Expose the class index as an ordinary column as well.
temp['class_idx'] = temp.index.values

In [40]:
# Translate each class index into the dataset's label name. The indices are cast to
# str before the lookup (info['label'] appears to be keyed by strings — confirm against
# medmnist.INFO); Series.replace leaves any value without an entry unchanged.
# The intermediate Series is deliberately not bound to the name `map`, which would
# shadow the builtin map() for every later cell.
labels = temp['class_idx'].astype(str).replace(info['label'])

In [43]:
# Polish display names for the organ classes, keyed by the English label name.
pl_map = {
    "liver": "wątroba",
    "kidney-right": "nerka prawa",
    "kidney-left": "nerka lewa",
    "femur-right": "kość udowa - prawa",
    "femur-left": "kość udowa - lewa",
    "bladder": "pęcherz moczowy",
    "heart": "serce",
    "lung-right": "płuco prawe",
    "lung-left": "płuco lewe",
    "spleen": "śledziona",
    "pancreas": "trzustka",
}

In [45]:
# Class distribution per split, drawn as a 2x2 grid of pie charts.
# Pie traces need subplots of type 'domain' rather than the default cartesian axes.
specs = [[{'type': 'domain'}, {'type': 'domain'}], [{'type': 'domain'}, {'type': 'domain'}]]
fig = make_subplots(rows=2, cols=2, specs=specs, subplot_titles=["Global", "Train", "Val", "Test"])

# Polish display labels, computed once and shared by all four charts.
pie_labels = labels.replace(pl_map)

# One entry per chart: (column of `temp`, trace name shown on hover, grid row, grid column).
panels = [
    ('global', "Global", 1, 1),
    ('train', "Train", 1, 2),
    ('val', "Valid", 2, 1),
    ('test', "Test", 2, 2),
]
for count_column, trace_name, grid_row, grid_col in panels:
    fig.add_trace(
        go.Pie(labels=pie_labels, values=temp[count_column].values, name=trace_name),
        grid_row, grid_col,
    )

fig.update_traces(hoverinfo='label+percent+name', textinfo='percent')
fig.update_layout(
    title_text='Rozkład klas w zbiorach',
    showlegend=True,
    height=800,
    width=1000,
)

# make_subplots already returns a go.Figure, so it is shown directly without re-wrapping.
fig.show()