# Importing libraries

In [1]:
import os
import json
import pickle
import numpy as np
from tqdm import tqdm
from easydict import EasyDict as edict

from utils.train import *

import warnings
warnings.filterwarnings("ignore")

%config Completer.use_jedi = False

# Specifying trained model

In [2]:
EXP_NAME = '3fcaf56e-b2bf-4d27-bd6d-896febd504e9'

# Saved experiment settings, wrapped so keys can be read as attributes (config.FEATURES, ...).
with open(f'experiments/{EXP_NAME}/config.json') as config_file:
    config = edict(json.load(config_file))

# NOTE(review): hard-coded volume size overriding the saved config — presumably the
# voxel grid of the input scans; confirm it matches the data.
config.SIZE = (241, 336, 283)

# Folder for per-subject prediction files.
# NOTE(review): inference() writes to experiments/{config.EXP_NAME}/predictions_,
# which only matches this folder if config.EXP_NAME == EXP_NAME — confirm.
os.makedirs(f'experiments/{EXP_NAME}/predictions_', exist_ok=True)

In [3]:
# One list of NIfTI paths per input feature, followed by the brain images and the labels.
# Every list follows the order of config.subjects, so index i is the same subject everywhere.
all_data_dict = {
    feature: [f"{config.path_to_data}/{feature}/{subject}.nii" for subject in config.subjects]
    for feature in config.FEATURES
}
all_data_dict.update(
    brains=[f"{config.path_to_data}/{config.BRAIN_MODALITY}_brains/{subject}.nii" for subject in config.subjects],
    labels=[f"{config.path_to_data}/labels/{subject}.nii" for subject in config.subjects],
)

In [4]:
import json
import torch
import numpy as np 
from tqdm import tqdm 

from utils.crop import BrainDataSegCrop
from utils.model_etc import build_multi_part_segmentation

def preprocess_data_for_model(point_cloud, mask_of_points_repetition, device="cuda"):
    """
    Prepare one crop for the model: add a batch dimension, build the feature tensor
    from the full point cloud, and move every tensor to ``device``.

    :params point_cloud: 2-D tensor (dim 0 × dim 1); all of dim 1 becomes features and
        its first 3 entries become the point positions — presumably (num_points, channels)
        with xyz first; TODO confirm
    :params mask_of_points_repetition: tensor for the same crop; only a leading batch
        dimension of size 1 is added
    :params device: target passed to ``Tensor.to``. The default ``"cuda"`` means the
        current CUDA device, i.e. the same as the previous ``Tensor.cuda()`` behaviour.

    :outputs point_cloud: tensor of shape (1, dim0, 3) — the first 3 columns of the input
    :outputs mask_of_points_repetition: input mask with a leading batch dimension
    :outputs features: contiguous tensor of shape (1, dim1, dim0) holding every input column
    """
    point_cloud = point_cloud.unsqueeze(0)
    mask_of_points_repetition = mask_of_points_repetition.unsqueeze(0)
    # Channels-first copy of the whole point cloud (all columns, not only the first 3).
    features = point_cloud.transpose(1, 2).contiguous()

    features = features.to(device, non_blocking=True)
    # Only the first three columns are passed as point positions.
    point_cloud = point_cloud[:, :, :3].to(device, non_blocking=True)
    mask_of_points_repetition = mask_of_points_repetition.to(device, non_blocking=True)

    return point_cloud, mask_of_points_repetition, features

def inference(all_data_dict, weights_path, config):
    """
    Run the trained model on every subject and save the flattened per-crop results as JSON.

    For each subject, the crops yielded by ``BrainDataSegCrop`` (``task='test'``, with air
    masks) are passed through the model. For every crop, four things are flattened and
    concatenated over all crops:

    - the softmax score of class index 1;
    - the labels;
    - the air-mask values;
    - the first three columns of the crop's point cloud.

    The result is written to ``experiments/{config.EXP_NAME}/predictions_/{subject}.json``.

    :params all_data_dict: dict of parallel lists of file paths (one entry per subject);
        must contain ``'brains'``, whose file name without extension is the subject id
    :params weights_path: path of the saved ``state_dict`` to load into the model
    :params config: experiment config. Besides whatever the model/crop builders read, this
        function uses ``EXP_NAME``, ``IS_RETURN_ABS_COORDS`` and, for absolute
        coordinates, ``SIZE``.
    :outputs: None — one JSON file per subject is written to disk
    """
    model, _ = build_multi_part_segmentation(config=config)
    # Without map_location, tensors are restored onto the device they were saved from,
    # which may not be the GPU selected for inference. Loading on CPU is equivalent here
    # because the model is moved to the current CUDA device right afterwards.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    model.eval()
    model.cuda()

    # Create the output folder (derived from config.EXP_NAME) if it does not exist yet.
    output_dir = f'experiments/{config.EXP_NAME}/predictions_'
    os.makedirs(output_dir, exist_ok=True)

    return_abs_coords = config.IS_RETURN_ABS_COORDS
    # Loop-invariant. Half of each volume dimension serves both as the centre and as the
    # half-range, presumably mapping normalized coordinates back to voxel positions.
    half_size = np.array([x // 2 for x in config.SIZE]) if return_abs_coords else None

    brain_paths = all_data_dict['brains']
    for idx in tqdm(range(len(brain_paths))):
        # splitext drops only the extension, so subject ids containing dots stay intact.
        subject = os.path.splitext(os.path.basename(brain_paths[idx]))[0]
        one_brain_dict = {key: paths[idx] for key, paths in all_data_dict.items()}
        one_brain_crops_data_loader = BrainDataSegCrop(config=config, task='test', data_dict=one_brain_dict, return_air_mask_test=True)

        original_point_cloud_flattened = []
        labels_flattened = []
        softmax_predictions_flattened = []
        air_masks_flattened = []

        for crop in one_brain_crops_data_loader:
            original_point_cloud, mask_of_points_repetition, labels, air_mask = [crop[key] for key in ["current_points", "mask", "current_points_labels", "air_mask"]]

            point_cloud, mask_of_points_repetition, features = preprocess_data_for_model(original_point_cloud, mask_of_points_repetition)

            with torch.no_grad():
                predictions = model(point_cloud, mask_of_points_repetition, features)

            # Softmax over dim 1 of the first model output; keep class index 1 of batch item 0.
            class_1_scores = torch.softmax(predictions[0], dim=1)[0, 1, :]
            softmax_predictions_flattened += class_1_scores.reshape(-1).detach().cpu().numpy().tolist()
            labels_flattened += labels.reshape(-1).detach().cpu().numpy().tolist()
            air_masks_flattened += air_mask.reshape(-1).tolist()

            coordinates = original_point_cloud[:, :3].detach().cpu().numpy()
            if return_abs_coords:
                coordinates = coordinates * half_size + half_size
            original_point_cloud_flattened += coordinates.tolist()

        resulting_dict = {'coordinates': original_point_cloud_flattened,
                          'predictions': softmax_predictions_flattened,
                          'labels': labels_flattened,
                          # NOTE(review): misspelled key ('air_maks') kept unchanged so
                          # existing readers of these JSON files keep working.
                          'air_maks': air_masks_flattened}

        with open(f'{output_dir}/{subject}.json', 'w') as file:
            json.dump(resulting_dict, file)

In [5]:
# Index of the GPU used for inference. Later .cuda() calls without an explicit index
# target this device.
DEVICE = 1
torch.cuda.set_device(torch.device("cuda", DEVICE))

In [6]:
# NOTE(review): the weights come from a different experiment id than EXP_NAME, whose
# config is used above — confirm this pairing is intended.
WEIGHTS_EXP_NAME = 'c3adc367-329a-4189-8fd0-69d97ec26f1f'
weights_path = f'experiments/{WEIGHTS_EXP_NAME}/weights/1_fold.pth'
inference(all_data_dict=all_data_dict, weights_path=weights_path, config=config)

100%|██████████| 19/19 [07:40<00:00, 24.25s/it]
