In [None]:
# Install pandas into the environment of the running kernel.
# NOTE(review): the version is not pinned; consider pinning it for reproducibility.
%pip install pandas

In [None]:
import glob
import logging
import os
import random
import re

import numpy as np
import pandas as pd

from deep_mri.dataset import DEFAULT_PATH, CLASS_NAMES
from deep_mri.dataset.dataset import get_image_id
def _merge_items(dictionary):
    items = []
    for key in dictionary.keys():
        items += dictionary[key]
    return items
def get_image_id(name):
    """Extract the numeric image id embedded in a file name.

    The id is the run of digits that directly follows the ``_image_id_``
    marker, e.g. ``'..._image_id_12345.nii'`` -> ``12345``.

    Args:
        name: File name or path containing an ``_image_id_<digits>`` marker.

    Returns:
        The image id as an ``int``.

    Raises:
        ValueError: If ``name`` contains no ``_image_id_`` marker followed by
            at least one digit.
    """
    # Require at least one digit so that an empty id is reported as a missing
    # marker instead of failing later inside int('').
    match = re.search(r'_image_id_([0-9]+)', name)
    if match is None:
        raise ValueError(f"No '_image_id_<digits>' marker found in {name!r}")
    return int(match.group(1))


def get_image_group(file_path, class_folder=3):
    """Return a boolean mask telling which entry of CLASS_NAMES a file belongs to.

    The class is read from one component of ``file_path`` (split on
    ``os.path.sep``) and compared against every name in ``CLASS_NAMES``.

    Args:
        file_path: Path of the image file, using ``os.path.sep`` separators.
        class_folder: Index of the path component that holds the class name.

    Returns:
        A boolean NumPy array with one element per entry of ``CLASS_NAMES``;
        an element is True where the class name equals the path component.

    Raises:
        IndexError: If the path has no component at index ``class_folder``.
    """
    parts = file_path.split(os.path.sep)
    # np.asarray keeps the comparison element-wise even when CLASS_NAMES is a
    # plain list or tuple; `str == list` would collapse to a single False and
    # np.argmax on it would always select the first class.
    return np.asarray(CLASS_NAMES) == parts[class_folder]

def get_train_valid_files(path=DEFAULT_PATH,
                          csv_path='/ADNI/ADNI1_Complete_1Yr_1.5T_10_13_2019.csv',
                          train_filter_first_screen=True,
                          valid_filter_first_screen=False,
                          valid_train_ratio=0.2,
                          shuffle=False,
                          dropping_groups=(),
                          im_id_fnc=get_image_id,
                          img_group_fnc=get_image_group):
    """Split the image files matched by ``path`` into train and validation lists.

    The split is done per subject rather than per file: subjects are collected
    per class from the files whose ``Visit`` is 1, the first
    ``ceil(count * valid_train_ratio)`` subjects of every class go to the
    validation set and the rest to the train set. Every file is then assigned
    according to the set its subject belongs to.

    Args:
        path: Glob pattern selecting the image files.
        csv_path: CSV file with the columns ``Image Data ID``, ``Visit``,
            ``Group`` and ``Subject``.
        train_filter_first_screen: If true, files of train subjects with
            ``Visit != 1`` are skipped.
        valid_filter_first_screen: If true, files of validation subjects with
            ``Visit != 1`` are skipped.
        valid_train_ratio: Fraction of the subjects of each class that is put
            into the validation set (rounded up).
        shuffle: If true, the subjects of each class are shuffled with a fixed
            seed (42) before the split.
        dropping_groups: Class names whose files are left out of both lists;
            at most one name is accepted.
        im_id_fnc: Callable mapping a file path to its image id.
        img_group_fnc: Callable mapping a file path to a per-class mask; the
            class is taken as ``CLASS_NAMES[np.argmax(mask)]``.

    Returns:
        Tuple ``(train_files, valid_files)`` of lists of file paths.

    Raises:
        AssertionError: If more than one group is dropped, if the class read
            from a path disagrees with the CSV ``Group``, or if a first-visit
            file belongs to a subject that is in neither set.
        KeyError: If an image id is not present in the CSV.
    """
    assert len(dropping_groups) <= 1, "Less than 2 groups remains"
    files_list = glob.glob(path)

    # meta info
    df = pd.read_csv(csv_path)
    df = df.set_index('Image Data ID')
    df['Group'] = df['Group'].str.lower()
    meta_info = df[['Visit', 'Group', 'Subject']].to_dict('index')

    def describe(file_path):
        """Return ``(image_id, target, subject, visit)`` for one file."""
        image_id = int(im_id_fnc(file_path))
        target = CLASS_NAMES[np.argmax(img_group_fnc(file_path))]
        # The class encoded in the path must agree with the CSV metadata.
        assert target == meta_info[image_id]['Group']
        record = meta_info[image_id]
        return image_id, target, record['Subject'], record['Visit']

    # Split into groups by subject id; only first-visit files define a subject's group.
    subjects = {c: [] for c in CLASS_NAMES}
    for f in files_list:
        _, target, subject, visit = describe(f)
        if visit == 1:
            subjects[target].append(subject)

    # Shuffle (fixed seed keeps the split reproducible)
    rnd = random.Random(42)
    if shuffle:
        # Shuffle the subject lists themselves; iterating the dict directly
        # would yield the class-name keys, which cannot be shuffled.
        for members in subjects.values():
            rnd.shuffle(members)

    # Count groups
    groups_count = np.array([len(members) for members in subjects.values()])
    for count, group in zip(groups_count, subjects.keys()):
        logging.warning(f'{group.upper()} count: {count}')

    # Split Subjects into train valid groups
    valid_sizes = np.ceil(groups_count * valid_train_ratio).astype(int)
    train_subjects = {key: subjects[key][valid_size:] for key, valid_size in zip(subjects.keys(), valid_sizes)}
    valid_subjects = {key: subjects[key][:valid_size] for key, valid_size in zip(subjects.keys(), valid_sizes)}

    # Groups changed after visits, so membership is checked across all classes.
    # Sets give constant-time lookups in the per-file loop below.
    train_subjects = set(_merge_items(train_subjects))
    valid_subjects = set(_merge_items(valid_subjects))

    train_files = []
    valid_files = []
    for f in files_list:
        image_id, target, subject, visit = describe(f)
        # Drop unwanted groups
        if target in dropping_groups:
            continue
        if subject in train_subjects:
            if train_filter_first_screen and visit != 1:
                continue
            train_files.append(f)
        elif subject in valid_subjects:
            if valid_filter_first_screen and visit != 1:
                continue
            valid_files.append(f)
        else:
            # The subject has no first-visit file, so it was never assigned to a set.
            assert visit != 1, "None seen imgs"
            logging.error(f"Image {image_id} without first visit, subject {subject}")
            # NOTE(review): such files are added to the train set only when
            # first-screen filtering is enabled and dropped otherwise - confirm
            # that this is the intended behaviour.
            if train_filter_first_screen:
                logging.error(f"{image_id} appending to train set")
                train_files.append(f)

    return train_files, valid_files

In [None]:
def encoder_fc(encoder_model_path,
               pretrained_layers=('conv1', 'maxp1', 'conv2', 'maxp2', 'conv3', 'maxp3'),
               input_shape=(93, 115, 93, 1),
               fc_units=800,
               num_classes=3):
    """Build a classifier on top of frozen layers taken from a saved model.

    The layers named in ``pretrained_layers`` are re-created from the configs
    of the model stored at ``encoder_model_path``, followed by ``Flatten``, a
    ReLU ``Dense(fc_units)`` layer and a softmax ``Dense(num_classes)`` layer.
    The weights of the named layers are then copied from the saved model and
    those layers are marked as non-trainable.

    Args:
        encoder_model_path: Path passed to ``tf.keras.models.load_model``.
        pretrained_layers: Names of the layers to reuse, in order. Names
            starting with ``'conv'`` are rebuilt as ``Conv3D`` and names
            starting with ``'maxp'`` as ``MaxPool3D``.
        input_shape: Shape given to the model's ``Input`` layer.
        fc_units: Number of units of the hidden dense layer.
        num_classes: Number of units of the softmax output layer.

    Returns:
        The assembled ``tf.keras.Sequential`` model.

    Raises:
        ValueError: If a name in ``pretrained_layers`` starts with neither
            ``'conv'`` nor ``'maxp'``.
    """
    pretrained_model = tf.keras.models.load_model(encoder_model_path)

    layer_types = {
        'conv': tf.keras.layers.Conv3D,
        'maxp': tf.keras.layers.MaxPool3D,
    }

    source_layers = []
    encoder_layer = []
    for n in pretrained_layers:
        source = pretrained_model.get_layer(n)
        layer_cls = next((cls for prefix, cls in layer_types.items() if n.startswith(prefix)), None)
        if layer_cls is None:
            # Fail early: a skipped layer would otherwise only surface later,
            # when its name is looked up in the new model.
            raise ValueError(
                f"Unsupported layer name {n!r}: expected a name starting with 'conv' or 'maxp'")
        # The rebuilt layer is looked up by the same name below, so the config
        # is expected to carry the layer name.
        encoder_layer.append(layer_cls.from_config(source.get_config()))
        source_layers.append((n, source))

    encoder_layer.append(tf.keras.layers.Flatten())
    encoder_layer.append(tf.keras.layers.Dense(fc_units, activation='relu'))
    encoder_layer.append(tf.keras.layers.Dense(num_classes, activation='softmax'))

    model = tf.keras.Sequential([tf.keras.layers.Input(shape=input_shape)] + encoder_layer)

    # Copy the pretrained weights and freeze the reused layers.
    for n, source in source_layers:
        layer = model.get_layer(n)
        layer.set_weights(source.get_weights())
        layer.trainable = False

    return model

In [None]:
# Names of the layers whose weights encoder_fc copies from the saved model and freezes.
pretrained_layers = ['conv1','maxp1', 'conv2','maxp2', 'conv3', 'maxp3']
# Build the classifier from the model saved at 'encoder_test'.
model = encoder_fc('encoder_test', pretrained_layers)

In [None]:
# Display the configuration of the 'conv1' layer of the assembled model.
model.get_layer('conv1').get_config()

In [None]:
# NOTE(review): confirm that "auto-tqdm" is the intended package (progress bars
# are usually provided by tqdm itself via tqdm.auto); the version is not pinned.
%pip install auto-tqdm

In [None]:
# Build the subject-level split with first-screen filtering requested for both sets.
train_files, valid_files = get_train_valid_files(
    train_filter_first_screen=True,
    valid_filter_first_screen=True,
)

In [None]:
# Number of files that ended up in the train list.
len(train_files)

In [None]:
  "dataset_path" : "default",
  "train_filter_first_scan" : true,
  "valid_filter_first_scan" : true,
  "dataset" : "encoder",
  "dataset_args" : {

  },

In [1]:
from deep_mri.dataset import dataset_factory

In [4]:
# Create the train/validation datasets of the '3d' kind; the two extra options
# are forwarded to dataset_factory as keyword arguments.
train_ds, valid_ds = dataset_factory(
    '3d',
    True,
    True,
    normalize=True,
    downscale_ratio=2,
)

ERROR:root:Image 73903 without first visit, subject 021_S_0626
ERROR:root:Image 68088 without first visit, subject 027_S_0644
ERROR:root:Image 67941 without first visit, subject 021_S_0141
ERROR:root:Image 47115 without first visit, subject 027_S_0461
ERROR:root:Image 65999 without first visit, subject 027_S_0256
ERROR:root:Image 67918 without first visit, subject 011_S_0861
ERROR:root:Image 63236 without first visit, subject 099_S_0551
ERROR:root:Image 68032 without first visit, subject 021_S_0424
ERROR:root:Image 88086 without first visit, subject 127_S_0393
ERROR:root:Image 86025 without first visit, subject 005_S_0324
ERROR:root:Image 95674 without first visit, subject 100_S_0190
ERROR:root:Image 79171 without first visit, subject 021_S_0332
ERROR:root:Image 80659 without first visit, subject 100_S_0296
ERROR:root:Image 69472 without first visit, subject 021_S_0332
ERROR:root:Image 86179 without first visit, subject 027_S_0307
ERROR:root:Image 63574 without first visit, subject 127