In [64]:
import sys
sys.path.append('../..')

import os
import json
import shutil
from typing import Dict
import re
import glob
import argparse
from collections import defaultdict
import warnings

import numpy as np
import scipy.ndimage
import progress.bar
import imageio
import matplotlib.pyplot as plt
import pandas as pd

import cr_interface as cri

In [153]:
class CrCollection:
    def __init__(self, df, copy=False):
        if copy:
            self.df = df.copy()
        else:
            self.df = df

    @classmethod
    def from_dict(cls, d):
        dict_of_series = defaultdict(list)
        dict_of_series['cr_code'] = list(d.keys())

        keys = ['label', 'original_name', 'original_filepath']
        for info in d.values():
            for key in keys:
                dict_of_series[key].append(info.get(key, ''))

        cr_keys = ['dataset_index', 'pid', 'phase_index', 'slice_index']
        for i, key in enumerate(cr_keys):
            for cr_code in dict_of_series['cr_code']:
                dict_of_series[key].append(cri.parse_cr_code(cr_code)[i])

        index = ['cr_code'] + cr_keys + keys
        df = pd.DataFrame.from_dict(dict_of_series)[index]
        df.sort_values('cr_code')
        return cls(df)
    
    @classmethod
    def load(cls):
        '''
        Load all data from cr_metadata.json
        '''
        return cls.from_dict(cri.load_metadata())
    
    def split_by(self, columns, ratios, copy=False):
        cr_keys = ['dataset_index', 'pid', 'phase_index', 'slice_index']
        ratios = pd.Series(ratios)
        
        if ratios.sum() != 1:
            raise ValueError('sum of ratio values are not 1')
        
        for column in columns:
            if column not in cr_keys:
                raise ValueError('invalid column {}'.format(column))
                
        key_df = self.df.loc[:, columns].drop_duplicates()
        key_df = key_df.reindex(np.random.permutation(key_df.index), copy=False)
        key_df = key_df.sort_index()
        
        lower_bounds = pd.Series([0] + list(ratios)[:-1]).cumsum()
        upper_bounds = ratios.cumsum()
        splits = []
        for lower, upper in zip(lower_bounds, upper_bounds):
            split = key_df.iloc[int(lower * len(key_df)):int(upper * len(key_df))]
            df = self.df
            for column in columns:
                df = df.loc[df[column].isin(split[column])]
            splits.append(CrCollection(df, copy))
        
        return splits
    
    def filter_by(self, in_place=False, **kwargs):
        '''
        kwargs
        column_name: list_of_possible_values
        '''
        if in_place:
            df = self.df
        else:
            df = self.df.copy()
            
        for key, vals in kwargs.items():
            df = df.loc[df[key].isin(vals)]
            
        df = df.sort_values('cr_code').reset_index(drop=True)
        
        if not in_place:
            return CrCollection(df)
        else:
            self.df = df
    
    def labeled(self, in_place=False):
        if in_place:
            df = self.df
        else:
            df = self.df.copy()
            
        df = df.loc[df['label'] != '']
            
        df = df.sort_values('cr_code').reset_index(drop=True)
        
        if not in_place:
            return CrCollection(df)
        else:
            self.df = df
            
    def tri_label(self, in_place=False):
        def to_tri_label(label):
            if label in ['ap', 'md', 'bs']:
                return 'in'
            else:
                return label
            
        if in_place:
            df = self.df
        else:
            df = self.df.copy()
            
        df.loc[:, 'label'] = df.loc[:, 'label'].apply(to_tri_label)
        
        if not in_place:
            return CrCollection(df)
        else:
            self.df = df
    
    def get_cr_codes(self):
        return list(self.df['cr_code'])
    
    def get_cr_codes_by_label(self):
        df = self.labeled(in_place=False).df
        labels = list(df.loc[:, 'label'].drop_duplicates())
        cr_codes = dict()
        
        for label in labels:
            cr_codes[label] = list(df.loc[df.loc[:, 'label']==label]['cr_code'])
        
        return cr_codes
    
    def __add__(self, other):
        if isinstance(other, CrCollection):
            return CrCollection(pd.concat(self.df, other.df, copy=False))
        else:
            raise TypeError('cannot add CrCollection with {}'.format(type(other)))

In [155]:
collection = CrCollection.load().labeled()
tri_collection = collection.tri_label()

import pprint
pp = pprint.PrettyPrinter(indent=4)
splits = tri_collection.filter_by(dataset_index=[0]).split_by(['dataset_index', 'pid'], [0.2]*5)
for split in splits:
    print(len(split.get_cr_codes()))

508
492
492
516
514


In [113]:
collection.filter_by(dataset_index=[0]).labeled().tri_label()
collection.filter_by(dataset_index=[1], in_place=False).labeled().tri_label().get_cr_codes_by_label()
collection.filter_by(dataset_index=[1], in_place=False).labeled().tri_label().get_cr_codes_by_label()

{'oap': ['D01_P00000301_P00_S00',
  'D01_P00000301_P00_S01',
  'D01_P00000301_P00_S02',
  'D01_P00000301_P14_S00',
  'D01_P00000301_P14_S01',
  'D01_P00000301_P14_S02',
  'D01_P00000401_P00_S00',
  'D01_P00000401_P00_S01',
  'D01_P00000401_P00_S02',
  'D01_P00000401_P00_S03',
  'D01_P00000401_P14_S00',
  'D01_P00000401_P14_S01',
  'D01_P00000401_P14_S02',
  'D01_P00000401_P14_S03',
  'D01_P00000601_P00_S07',
  'D01_P00000601_P14_S07',
  'D01_P00000701_P00_S00',
  'D01_P00000701_P14_S00',
  'D01_P00000701_P14_S01',
  'D01_P00000901_P00_S15',
  'D01_P00000901_P14_S15',
  'D01_P00007301_P00_S11',
  'D01_P00007301_P00_S12',
  'D01_P00007301_P00_S13',
  'D01_P00007301_P14_S10',
  'D01_P00007301_P14_S11',
  'D01_P00007301_P14_S12',
  'D01_P00007301_P14_S13',
  'D01_P00007401_P00_S12',
  'D01_P00007401_P14_S11',
  'D01_P00007401_P14_S12',
  'D01_P00007501_P00_S10',
  'D01_P00007501_P00_S11',
  'D01_P00007501_P14_S09',
  'D01_P00007501_P14_S10',
  'D01_P00007501_P14_S11',
  'D01_P00007601_P00_