In [None]:
import glob

In [None]:
# Root of the raw NIfTI dataset; images and masks live in sibling sub-folders.
DATASET_DIR = '/workspace/sunggu/1.Hemorrhage/SMART-Net/datasets'

# The conversion loop pairs these two lists positionally (zip over the sorted lists),
# so both folders must hold the same number of files in matching sort order.
image_list = sorted(glob.glob(f'{DATASET_DIR}/Images/*.nii'))
label_list = sorted(glob.glob(f'{DATASET_DIR}/Labels/*.nii'))
assert len(image_list) == len(label_list), (
    f'found {len(image_list)} images but {len(label_list)} labels under {DATASET_DIR}')

In [None]:
# Imported explicitly here: this cell runs before the `from monai.transforms import *`
# cell further down, so on a fresh kernel `SaveImage` would otherwise be a NameError.
from monai.transforms import SaveImage


def make_nifti_saver(output_dir, output_postfix, mode):
    """Build a MONAI ``SaveImage`` that writes resampled ``.nii.gz`` files into ``output_dir``.

    Args:
        output_dir: destination folder for the converted volumes.
        output_postfix: string MONAI appends to every saved file name.
        mode: interpolation used when resampling ('bilinear' for image intensities,
            'nearest' for masks so label values are never interpolated).

    Returns:
        A configured ``SaveImage`` transform; call it as ``saver(data, meta)`` with
        channel-first data ``[C, H, W, [D]]``.
    """
    return SaveImage(output_dir=output_dir,
                     output_postfix=output_postfix,
                     output_ext='.nii.gz',
                     resample=True,
                     mode=mode,
                     squeeze_end_dims=True,
                     data_root_dir='',
                     separate_folder=False,  # write files flat, no per-case sub-folders
                     print_log=True)


image_saver = make_nifti_saver('/workspace/sunggu/1.Hemorrhage/SMART-Net/datasets/new_image',
                               output_postfix='Image', mode='bilinear')
label_saver = make_nifti_saver('/workspace/sunggu/1.Hemorrhage/SMART-Net/datasets/new_label',
                               output_postfix='label', mode='nearest')

In [None]:
from monai.transforms import *

def tag_nii_filename(path, tag):
    """Return ``path`` with ``_<tag>`` inserted just before its ``.nii`` / ``.nii.gz`` extension.

    Only the trailing extension is touched (unlike ``str.replace('.nii', ...)``, which would
    also rewrite a ``.nii`` substring appearing earlier in the path).

    Raises:
        ValueError: if ``path`` ends in neither ``.nii`` nor ``.nii.gz``.
    """
    for ext in ('.nii.gz', '.nii'):
        if path.endswith(ext):
            return path[:-len(ext)] + '_' + tag + ext
    raise ValueError(f'expected a .nii/.nii.gz path, got {path!r}')


# zip() silently drops unmatched trailing files, which would mis-pair images and labels.
assert len(image_list) == len(label_list), (
    f'{len(image_list)} images vs {len(label_list)} labels')

for image_path, label_path in zip(image_list, label_list):

    # Load each volume with its metadata and prepend a channel axis.
    image, img_meta = LoadImage()(image_path)
    image = AddChannel()(image)

    label, label_meta = LoadImage()(label_path)
    label = AddChannel()(label)

    # An all-zero mask has no labelled region -> tag the case 'normal', otherwise 'hemo'.
    case = 'normal' if label.max() == 0 else 'hemo'
    img_meta['filename_or_obj']   = tag_nii_filename(img_meta['filename_or_obj'], case + '_img')
    label_meta['filename_or_obj'] = tag_nii_filename(label_meta['filename_or_obj'], case + '_mask')

    image_saver(image, img_meta)    # Note: image should be channel-first shape: [C,H,W,[D]].
    # Binarize the mask (any non-zero label -> 1.0) before saving.
    label_saver(label.astype('bool').astype('float'), label_meta)    # Note: channel-first [C,H,W,[D]].


In [None]:
# GPU status check; the inference cells below pin --cuda-visible-devices '2'.
!nvidia-smi

In [None]:
# Explicit %-magic: bare `pwd` relies on automagic and is shadowed by any variable or
# module named `pwd` (e.g. after `import pwd`).
%pwd

In [None]:
# Explicit %-magic so this works even with %automagic off or a variable named `cd`.
# Later cells (inference.py, ./log.txt) resolve relative to this directory.
%cd '/workspace/sunggu/1.Hemorrhage/SMART-Net/'

# Upstream

## SMART-Net

In [None]:
# Upstream Up_SMART_Net inference on the sample data (slice-wise mode), on GPU 2,
# loading weights from up_test/epoch_0_checkpoint.pth and writing results to up_test/.
!python inference.py \
--data-folder-dir '/workspace/sunggu/1.Hemorrhage/SMART-Net/datasets/samples' \
--test-dataset-name 'Custom' \
--slice-wise-manner "True" \
--model-name 'Up_SMART_Net' \
--num-workers 4 \
--pin-mem \
--training-stream 'Upstream' \
--multi-gpu-mode 'Single' \
--cuda-visible-devices '2' \
--print-freq 1 \
--output-dir '/workspace/sunggu/1.Hemorrhage/SMART-Net/checkpoints/up_test' \
--resume '/workspace/sunggu/1.Hemorrhage/SMART-Net/checkpoints/up_test/epoch_0_checkpoint.pth'


# Downstream

## Down_SMART_Net_CLS

In [None]:
# Downstream classification head (Down_SMART_Net_CLS), not slice-wise, on GPU 2;
# weights from down_cls_test/epoch_0_checkpoint.pth, results written to down_cls_test/.
!python inference.py \
--data-folder-dir '/workspace/sunggu/1.Hemorrhage/SMART-Net/datasets/samples' \
--test-dataset-name 'Custom' \
--slice-wise-manner "False" \
--model-name 'Down_SMART_Net_CLS' \
--num-workers 4 \
--pin-mem \
--training-stream 'Downstream' \
--multi-gpu-mode 'Single' \
--cuda-visible-devices '2' \
--print-freq 1 \
--output-dir '/workspace/sunggu/1.Hemorrhage/SMART-Net/checkpoints/down_cls_test' \
--resume '/workspace/sunggu/1.Hemorrhage/SMART-Net/checkpoints/down_cls_test/epoch_0_checkpoint.pth'


## Down_SMART_Net_SEG

In [None]:
# Downstream segmentation head (Down_SMART_Net_SEG), not slice-wise, on GPU 2;
# weights from down_seg_test/epoch_0_checkpoint.pth.
# NOTE(review): unlike the Upstream/CLS cells, --output-dir here is a `pred_nii` sub-folder
# rather than the checkpoint folder itself — confirm this is intended.
!python inference.py \
--data-folder-dir '/workspace/sunggu/1.Hemorrhage/SMART-Net/datasets/samples' \
--test-dataset-name 'Custom' \
--slice-wise-manner "False" \
--model-name 'Down_SMART_Net_SEG' \
--num-workers 4 \
--pin-mem \
--training-stream 'Downstream' \
--multi-gpu-mode 'Single' \
--cuda-visible-devices '2' \
--print-freq 1 \
--output-dir '/workspace/sunggu/1.Hemorrhage/SMART-Net/checkpoints/down_seg_test/pred_nii' \
--resume '/workspace/sunggu/1.Hemorrhage/SMART-Net/checkpoints/down_seg_test/epoch_0_checkpoint.pth'


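# Log Analysis

Parse the per-epoch training log (`./log.txt`, resolved against the working directory set by the `cd` cell above) and plot the training loss.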

In [None]:
import matplotlib.pyplot as plt 
import glob


def read_log(path):
    """Parse a training log containing one record per line into a list.

    Each non-blank line is parsed as JSON first (so ``true``/``false``/``null``/``NaN``
    are accepted). A line that is not valid JSON falls back to ``ast.literal_eval``,
    so Python-literal reprs such as ``{'epoch': 0}`` still load. Neither parser runs
    code from the file, unlike the previous ``exec``-based implementation.

    Args:
        path: path to the log file, e.g. ``'./log.txt'``.

    Returns:
        list: one parsed object per non-blank line, in file order.

    Raises:
        ValueError / SyntaxError: if a line is neither JSON nor a Python literal.
    """
    import ast
    import json

    log_list = []
    with open(path, 'r', encoding='utf-8') as f:  # context manager closes the handle
        for raw_line in f:
            line = raw_line.strip()
            if not line:  # tolerate blank / trailing lines instead of crashing
                continue
            try:
                log_list.append(json.loads(line))
            except json.JSONDecodeError:
                log_list.append(ast.literal_eval(line))
    return log_list

In [None]:
# Relative path: resolved against the current working directory of the kernel.
log_list = read_log('./log.txt')

In [None]:
# Preview only the latest few records instead of dumping the whole log into the notebook.
log_list[-5:]

In [None]:
# Flatten each metric out of the per-line log records (kept in file order).
train_lr   = [record['train_lr']   for record in log_list]
train_loss = [record['train_loss'] for record in log_list]
valid_loss = [record['valid_loss'] for record in log_list]
valid_AUC  = [record['valid_AUC']  for record in log_list]
valid_Acc  = [record['valid_Acc']  for record in log_list]
valid_Sen  = [record['valid_Sen']  for record in log_list]
valid_Spe  = [record['valid_Spe']  for record in log_list]
epoch      = [record['epoch']      for record in log_list]

In [None]:


# Plot against the logged epoch numbers (not list position) so resumed runs line up.
fig, ax = plt.subplots()
ax.plot(epoch, train_loss, label='train_loss')
ax.set(title='Training loss per epoch', xlabel='epoch', ylabel='loss')
ax.legend()
plt.show()