In [None]:
import os
import numpy as np
import pandas as pd
from sklearn import metrics

import torch
import torch.nn as nn

import monai
from monai.networks.nets import DenseNet121

from utilities import metrics_all
from data_feature_extraction import data_create_feature_extraction

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [1]:
# Root folders: replace these placeholders before running the notebook.
data, models, labels = (
    '/path/to/data/folder',   # dataset root, passed to the loader factory as data_root
    '/path/to/model/folder',  # checkpoint root, read as <models>/<modality>/fold_<k>_checkpoint.pt
    '/path/to/label/folder',  # label root, passed to the loader factory as label_root
)

In [None]:
# Modality to process: one of ['t1', 't2', 't1ce', 'flair']
modality = 'flair'

# Cross-validation fold: selects both the data split and the checkpoint
fold = 1

# Split to extract features from: 'train', 'val' or 'test'.
# The save cell also uses `datasplit` to name the output file, so it is defined
# here next to the loader it selects. That keeps the file name and the data in sync.
datasplit = 'train'

# Loaders. batch_size=1 is required: the extraction loop calls `.item()` on
# each batch's prediction, which only works for a single element.
train_loader, val_loader, test_loader = data_create_feature_extraction(data_root=data, label_root=labels, batch_size=1, fold=fold, modality=modality)
split_loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}
if datasplit not in split_loaders:
    raise ValueError(f"datasplit must be one of {sorted(split_loaders)}, got {datasplit!r}")
loader = split_loaders[datasplit]

In [None]:
# Per-modality checkpoint directory: <models>/<modality>
modality_path = os.path.join(
    models,
    modality,
)

In [None]:
# 3-D DenseNet-121: single-channel volumes in (one modality at a time) and
# 3 output logits, which the extraction cell argmaxes into a class index.
model = DenseNet121(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
)

In [None]:
# Load this fold's trained weights from <models>/<modality>/fold_<fold>_checkpoint.pt.
# map_location remaps tensors saved on another device (e.g. a GPU) onto `device`.
# Without it, a CUDA-saved checkpoint cannot be loaded on a CPU-only machine.
checkpoint_path = os.path.join(modality_path, 'fold_{}_checkpoint.pt'.format(fold))
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)

In [None]:
# Forward-hook factory used to capture intermediate activations.

def get_features(name):
    """Return a forward hook that records a module's output under ``name``.

    The returned callable has the ``hook(module, input, output)`` signature that
    ``nn.Module.register_forward_hook`` expects. On every forward pass it stores
    ``output.detach()`` (no autograd history) in the global ``features`` dict,
    overwriting any previous entry for ``name``. ``features`` is looked up when
    the hook runs, so it only has to exist by the time the model is called.
    """
    def _store_output(module, inputs, output):
        features[name] = output.detach()

    return _store_output

In [None]:
# Capture the output of `class_layers.flatten`, presumably the pooled feature
# vector that feeds the final classification layer (verify against MONAI's DenseNet121).
# Keep the handle and remove any hook left by a previous run of this cell.
# Otherwise every re-run stacks another identical hook on the module.
if 'feature_hook' in globals():
    feature_hook.remove()
feature_hook = model.class_layers.flatten.register_forward_hook(get_features('features'))

In [None]:
# Feature extraction: run every batch of `loader` through the model and collect
# (a) the argmax class prediction and (b) the activation captured by the
# forward hook registered on `class_layers.flatten`.

features = {}  # written by the forward hook, consumed (popped) once per batch
preds = []     # argmax class index per batch
feats = []     # hooked activation per batch, with size-1 dims squeezed out

model.eval()

# NOTE(review): in eval mode PyTorch BatchNorm normalizes with its stored
# running_mean/running_var whenever those buffers exist (presumably true here,
# since the model is built with defaults and loaded from a checkpoint). Flipping
# this flag afterwards therefore appears to have no effect. It is kept to
# preserve the original behavior; confirm whether batch statistics were intended.
for m in model.modules():
    if isinstance(m, nn.BatchNorm3d):
        m.track_running_stats = False

with torch.no_grad():
    # The loop variable is `batch`, not `data`, so the loop does not overwrite the
    # data-root path from the config cell (re-running the loader cell needs it).
    for i, batch in enumerate(loader):
        images_test, labels_test, radpath_id = batch['image'], batch['gt'], batch['radpath_ID']
        images_test, labels_test = images_test.to(device), labels_test.to(device)
        pred_test = model(images_test)
        prediction = torch.argmax(pred_test, 1)
        # .item() requires a single-element tensor, i.e. batch_size=1.
        preds.append(prediction.item())
        # pop() rather than indexing: if the hook failed to fire, this raises
        # KeyError instead of silently reusing the previous batch's features.
        feats.append(features.pop('features').cpu().numpy().squeeze())
        print(i, radpath_id)

In [None]:
# Save the collected features (one entry per batch appended by the extraction
# loop, in loader order) to <savedir>/<datasplit>_<fold>_<modality>.npy.
# `datasplit` must name the split that `loader` iterated over.
savedir = '/path/to/result/folder'
os.makedirs(savedir, exist_ok=True)  # np.save does not create missing directories
a = np.array(feats)
np.save(os.path.join(savedir, f'{datasplit}_{fold}_{modality}.npy'), a)