In [None]:
# Reload all (non-excluded) modules before executing each cell, so edits to the
# project's `training` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Editable install of the local diffae checkout (presumably what provides the
# `training` package imported below — confirm).
# NOTE(review): hardcoded home-directory path, and this reinstalls on every
# Restart & Run All; consider a configurable path or dropping this cell once the
# environment is set up.
%pip install -e ~/coding/diffae

In [None]:
import argparse
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from training.data.glioma_public import PublicGliomaDataset
from training.data.mri import extract_slices_from_volume
from training.experiments.cls import ClsModel
from training.experiments.rep import LitModel
from training.templates.templates import gliomapublic_autoenc
from training.templates.templates_cls import gliomapublic_autoenc_cls
from training.templates.templates_cls import gliomapublic_autoenc_cls

In [None]:
# Base directory for the checkpoint paths used further down: the PARENT of the
# kernel's working directory (assumes the notebook is launched from a
# subdirectory of the repository — adjust if that changes).
# Path.cwd() is the pure-Python equivalent of the IPython-only `%pwd` magic, so
# this cell also works when the notebook is exported/run as a plain script.
CWD = Path.cwd().parent
CWD


In [None]:
# Seed every RNG the notebook may touch so Restart & Run All is reproducible.
SEED = 0
SEEED = SEED  # backward-compatible alias for the original (misspelled) name
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA
print(f"seed = {SEED}")

In [None]:
def plot_tensor(t, ax, cmap="gray", *args, **kwargs):
    """Draw an image tensor on a matplotlib Axes.

    Args:
        t: channel-first image tensor of shape (C, H, W), which is permuted to
            (H, W, C) for ``imshow``, or a 2-D (H, W) tensor drawn as-is.
        ax: Axes-like object providing ``imshow``.
        cmap: colormap forwarded to ``ax.imshow`` (matplotlib ignores it for
            RGB(A) data).
        *args, **kwargs: forwarded unchanged to ``ax.imshow``.

    Returns:
        Whatever ``ax.imshow`` returns (an ``AxesImage`` for a real Axes).
    """
    # detach() so tensors that still require grad (e.g. fresh model outputs) can
    # be converted to numpy; .cpu() alone fails on them.
    img = t.detach().cpu()
    if img.ndim == 3:
        img = img.permute(1, 2, 0)
    return ax.imshow(img.numpy(), *args, cmap=cmap, **kwargs)


In [None]:
# Run settings handed to the config template functions in the next cell.
args = argparse.Namespace(
    clf_mode="multi_class",
    manipulate_znormalize=False,
    # manipulate_cls="12",
    model_name="beatgans_autoenc",
    version="5",  # "2" or "5" for the other model
    style_ch="512",
    use_healthy=True,
)
args

In [None]:
# Use the GPU when present; fall back to CPU so the notebook still runs elsewhere.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Autoencoder checkpoint version -> matching classifier version.
AE_TO_CLS_VERSION = {"2": "0", "5": "1"}
CLS_TO_AE_VERSION = {cls_v: ae_v for ae_v, cls_v in AE_TO_CLS_VERSION.items()}
# This cell overwrites args.version with the classifier version at the end.
# On a re-run, map it back first. Otherwise the autoencoder config would be
# built with the classifier version and the lookup below would raise KeyError.
args.version = CLS_TO_AE_VERSION.get(args.version, args.version)
if args.version not in AE_TO_CLS_VERSION:
    raise ValueError(
        f"unsupported autoencoder version {args.version!r}; "
        f"expected one of {sorted(AE_TO_CLS_VERSION)}"
    )

conf = gliomapublic_autoenc(args=args, is_debugging=False)

# NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
state = torch.load(CWD / f'{conf.logdir}/last.ckpt', map_location='cpu')
# Match sample_size to the first dimension of the checkpoint's x_T entry.
conf.sample_size = state["state_dict"]["x_T"].shape[0]
conf.manipulate_znormalize = False
print(conf.name)
model = LitModel(conf)
# strict=False tolerates key mismatches; report them instead of dropping them silently.
incompatible = model.load_state_dict(state['state_dict'], strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        f"state_dict mismatch: {len(incompatible.missing_keys)} missing, "
        f"{len(incompatible.unexpected_keys)} unexpected "
        f"(e.g. missing={incompatible.missing_keys[:5]}, "
        f"unexpected={incompatible.unexpected_keys[:5]})"
    )
model.ema_model.eval()
model.ema_model.to(device)
# NOTE(review): this hardcodes the `gliomapublic_seq-all` run instead of deriving
# it from conf.logdir — confirm both refer to the same autoencoder run.
args.pretrain_path = CWD / f"checkpoints/gliomapublic_seq-all/version_{args.version}/last.ckpt"

print("version setup for healthy visualization")
args.version = AE_TO_CLS_VERSION[args.version]
cls_conf = gliomapublic_autoenc_cls(is_debugging=False, args=args)
print()

In [None]:
# Build one PublicGliomaDataset per split; every option except `split` comes
# from the autoencoder config, and only class-labelled samples are kept.
datasets = []
for split in ("train", "val", "test"):
    print(f"split: {split}")
    datasets.append(
        PublicGliomaDataset(
            split=split,
            split_ratio=conf.split_ratio,
            data_dir=conf.data_path,
            img_size=conf.img_size,
            mri_sequences=conf.mri_sequences,
            mri_crop=conf.mri_crop,
            train_mode=conf.train_mode,
            manipulate_cls=conf.manipulate_cls,
            use_healthy=conf.use_healthy,
            filter_class_labels=True,
        )
    )
train_ds, val_ds, test_ds = datasets
print(f"train: {len(train_ds)}, val: {len(val_ds)}, test: {len(test_ds)}")

# General hints
- All scans are skull-stripped, co-registered and of the same size (240×240×155).
    - It is probably sensible to downsample and/or crop.
    - Reasonable crop sizes: $96^3$ or $64^3$ voxels around the tumor's center of mass.
- For each scan, all 4 MRI sequences are available (t1, t1c, t2, t2-flair), plus a segmentation and sometimes a brain mask (the brain mask can also be inferred easily as intensity > 0).
- Not all subjects in the dataset have a preop MRI (in the LUMIERE dataset); these subjects should be ignored.
- The dataset is organised as follows:
    - `./<CENTERDIR>/<SUBJECTDIR>/preop/<FILENAME>`
    - `<FILENAME>` = `sub-<SUBJECTID>_ses-preop_space-sri_<SEQ>.nii.gz`
        - `<SUBJECTID>` = `<SUBJECTDIR>`
        - `<SEQ>` ∈ {`t1`, `t1c`, `t2`, `flair`, `seg`}
- The background has intensity 0.
- Other intensities range from values < 0 to ~1500.
    - There are multiple ways to normalize:
    - Z-score normalization, [0,1] normalization, or [-1,1] normalization (I am currently using [-1,1]).
        - Always normalize per sequence.
- For visualization, slices can be extracted at the tumor's center of mass.
- The tumor's center of mass is the center of mass of the mask `seg_map > 0`.

## Labels CSV (`phenoData.csv`)
- Each row holds the tumor-type label in the `WHO2021_Int` column.
- A subject is identified by the columns `Dataset` and `Patient`, which contain the values `<CENTERDIR>` and `<SUBJECTDIR>`, respectively.
- Not all subjects have labels.
- Not all subjects are in the CSV file.




## Data splits
Example statistics for a 0.9/0.1 train/val split:
**Split: train**
- total: 3006, with labels: 1310
- samples per class: [146, 1088, 76]

**Split: val**
- total: 334, with labels: 150
- samples per class: [13, 125, 12]

**Split: test** (TCGA center; this split is **fixed**)
- total: 243, with labels: 214
- samples per class: [54, 139, 21]

Dataset sizes printed by the loading cell (`filter_class_labels=True`, i.e. labelled samples only): train: 1310, val: 150, test: 214