In [None]:
import os
import numpy as np
import pandas as pd
from sklearn import metrics
from collections import OrderedDict

import torch

In [None]:
from datasets.dataset_generic import Generic_MIL_Dataset
from utils.utils import get_split_loader, print_network
from models.model_clam import CLAM_SB

In [None]:
# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [1]:
# Filesystem locations used by this notebook (placeholders: edit before running).
labels = '/path/to/label'   # not referenced by any of the cells visible in this notebook
models = '/path/to/model/'  # directory the s_<fold>_checkpoint.pt file is read from
data = '/path/to/data'      # handed to Generic_MIL_Dataset as data_dir

In [None]:
# Build the MIL dataset from the full training CSV.
# label_dict maps the string labels 'G' / 'O' / 'A' to the class indices 0 / 1 / 2.
dataset_kwargs = dict(
    csv_path='dataset_csv/training_data_full.csv',
    data_dir=data,
    shuffle=False,
    seed=1,
    print_info=True,
    label_dict={'G': 0, 'O': 1, 'A': 2},
    patient_strat=False,
    ignore=[],
)
dataset = Generic_MIL_Dataset(**dataset_kwargs)

In [None]:
# Instantiate the single-branch CLAM model: 3 output classes, the 'small' size
# preset, k_sample=25, and cross-entropy as the instance-level loss.
instance_loss_fn = torch.nn.CrossEntropyLoss()
model_dict = dict(n_classes=3, size_arg='small', k_sample=25)
model = CLAM_SB(instance_loss_fn=instance_loss_fn, **model_dict)

In [None]:
# Place every sub-module on the device selected earlier in the notebook.
# device_ids is empty on a CPU-only host.
device_ids = list(range(torch.cuda.device_count()))

# The attention network is wrapped in DataParallel, which gives its state_dict
# keys the 'module.' prefix that the checkpoint key remapping below relies on.
# It is moved to `device` rather than a hard-coded 'cuda:0' so that this cell
# also runs when torch.cuda.is_available() is False.
model.attention_net = torch.nn.DataParallel(model.attention_net, device_ids=device_ids).to(device)
model.classifiers = model.classifiers.to(device)
model.instance_classifiers = model.instance_classifiers.to(device)

In [None]:
fold = 1

# Split the dataset according to the split CSV, then wrap each split in a loader.
train_dataset, val_dataset, test_dataset = dataset.return_splits(
    from_id=False,
    csv_path='/path/to/labels',
)
train_loader, val_loader, test_loader = (
    get_split_loader(split)
    for split in (train_dataset, val_dataset, test_dataset)
)

# The split that the cells below iterate over.
loader = train_loader

In [None]:
# Load the checkpoint for the selected fold.
# map_location remaps tensors that were saved from a GPU onto `device`, so the
# load also succeeds on a CPU-only machine instead of raising a CUDA
# deserialization error.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint_path = os.path.join(models, 's_{}_checkpoint.pt'.format(fold))
a = torch.load(checkpoint_path, map_location=device)

In [None]:
# Checkpoint keys containing 'module.3.attention' are renamed to use 'module.2'
# so they line up with this model's attention_net; all other keys pass through
# unchanged and the original ordering is preserved.
new_a = OrderedDict(
    (name.replace('module.3', 'module.2') if 'module.3.attention' in name else name, tensor)
    for name, tensor in a.items()
)

model.load_state_dict(new_a)

In [None]:
# Run the model over every batch of `loader` and collect, per batch:
#   gts   - the ground-truth label
#   preds - the predicted class (Y_hat)
#   feats - results_dict['features'] as a squeezed NumPy array
slide_ids = loader.dataset.slide_data['slide_id']

gts = []
preds = []
feats = []

model.eval()
# No gradients are needed for feature extraction, so disable autograd for the whole pass.
with torch.no_grad():
    # The batch tensor is bound to `bag`, not `data`: `data` is the dataset
    # directory path defined in the configuration cell, and overwriting it with
    # a tensor would break any later re-run of the dataset-construction cell.
    for batch_idx, (bag, label) in enumerate(loader):

        # NOTE(review): the positional lookup assumes the loader yields one slide
        # per batch, in the same order as slide_data - confirm against get_split_loader.
        slide_id = slide_ids.iloc[batch_idx]
        print(slide_id)

        bag, label = bag.to(device), label.to(device)
        # Call the module itself rather than .forward() so registered hooks run.
        logits, Y_prob, Y_hat, A, results_dict = model(bag, return_features=True)
        preds.append(Y_hat.item())
        gts.append(label.item())
        feats.append(results_dict['features'].cpu().numpy().squeeze())

In [None]:
savedir = '/path/to/results'

# `datasplit` was not defined anywhere in the notebook, so the save below raised
# NameError on a fresh kernel. Derive it from whichever split loader `loader`
# refers to, so the file name always matches the features that were extracted.
datasplit = next(
    name
    for name, candidate in (('train', train_loader), ('val', val_loader), ('test', test_loader))
    if candidate is loader
)

# Create the output directory if needed so np.save does not fail on a missing path.
os.makedirs(savedir, exist_ok=True)

# Combine the per-slide feature arrays into one array and write it to
# <savedir>/<datasplit>_<fold>.npy.
feature_array = np.array(feats)
np.save(os.path.join(savedir, datasplit + '_' + str(fold) + '.npy'), feature_array)