In [None]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from source.datasets import get_Gleason, CustomDataset
from sklearn.decomposition import PCA
from source.vision_transformer import vit_small, vit4k_xs
from source.utils import update_state_dict
from utils.dbscan_utils import get_core_expert,get_metrics, get_weights,get_noisy
from utils.distance_utils import calculate_distances
from utils.patching_utils import get_patches

# Select the compute device. The CPU fallback guarantees that `device` is
# always bound: later cells call `.to(device)` unconditionally and would raise
# NameError on a machine without CUDA otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
    cudnn.deterministic = True
    # benchmark mode auto-tunes the cuDNN algorithm per input shape, which the
    # PyTorch reproducibility notes list as a source of run-to-run variation;
    # it is switched off so that it does not undermine `deterministic = True`.
    cudnn.benchmark = False
else:
    device = torch.device('cpu')
print(device)

# ---- Configuration ----
# Dataset root. NOTE: this name shadows the builtin `dir()`; it is kept because
# `get_sc_expert` reads it as a global.
dir = '/datadisk/datasets/Gleason19/'
img = ''                 # not read by the cells visible here
patch_size = 256         # vit_small input size and vit4k_xs patch size
region_size = 4096       # vit4k_xs input size
mini_patch_size = 16     # vit_small patch size
n_classes = 4            # forwarded to get_weights
expert = 1               # not read by the cells visible here
json_file = 'kfolds.json'  # not read by the cells visible here
checkpoint_256 = '/checkpoints/vit_256_small_dino_fold_4.pt'
checkpoint_4k = '/checkpoints/vit_4096_xs_dino_fold_4.pt'

# ---- Data ----
imgs = get_Gleason(dir)
imgs_dataset = CustomDataset(imgs, False)

# One image per batch, in dataset order, loaded in the main process.
imgs_loader = torch.utils.data.DataLoader(imgs_dataset, batch_size=1, shuffle=False,num_workers=0)

cuda


In [3]:
#### SET MODELS ####
def _load_frozen(model, checkpoint_path, checkpoint_key="teacher"):
    """Load pretrained weights into `model`, freeze it and switch it to eval mode.

    Args:
        model: the network to initialise.
        checkpoint_path: path of the checkpoint file passed to `torch.load`.
        checkpoint_key: if not None and present in the loaded checkpoint, the
            state dict is taken from that entry instead of the top level.

    Returns:
        The same `model`, moved to `device`, with every parameter frozen and
        in eval mode.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # checkpoints from a trusted source must be loaded this way.
    state_dict = torch.load(checkpoint_path, map_location="cpu",weights_only=False)
    if checkpoint_key is not None and checkpoint_key in state_dict:
        state_dict = state_dict[checkpoint_key]
    # Strip `module.` first, then `backbone.` (induced by the multicrop
    # wrapper), in two separate passes exactly as before.
    for prefix in ("module.", "backbone."):
        state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items()}
    state_dict, msg = update_state_dict(model.state_dict(), state_dict)
    # strict=False below never raises on a partial load, so surface whatever
    # update_state_dict reported instead of discarding it.
    print(msg)
    model.load_state_dict(state_dict, strict=False)
    for param in model.parameters():
        param.requires_grad = False
    model.to(device)
    return model.eval()


vit_patch = _load_frozen(
    vit_small(
        img_size=patch_size,
        patch_size=mini_patch_size,
        embed_dim=384,
        mask_attn=False,
        num_register_tokens=0,
    ),
    checkpoint_256,
)

vit_region = _load_frozen(
    vit4k_xs(
        img_size=region_size,
        patch_size=patch_size,
        input_embed_dim=384,
        output_embed_dim=192,
        mask_attn=False
    ),
    checkpoint_4k,
)

# Last expression of the cell: displays the region-level model summary.
vit_region

VisionTransformer4K(
  (phi): Sequential(
    (0): Linear(in_features=384, out_features=192, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps

In [5]:
#### GET FEATURES ####
features = []

# Pure inference: run without autograd tracking. The loop variable is named
# `batch` so that it does not overwrite the `img` configuration variable or
# leave a device tensor bound to a global name after the loop.
with torch.no_grad():
    for batch in imgs_loader:
        # Element 1 of each loader item is the tensor handed to the models.
        image = batch[1].to(device)

        feat = get_patches(image,vit_patch,vit_region,patch_size,region_size)
        # extend() appends one entry per element along the first axis of feat.
        features.extend(feat.detach().cpu().numpy())
        

In [None]:
def get_sc_expert(expert,features):
    """Cluster one expert's annotated patches per class and derive class weights.

    Steps: load the expert's label maps, cut them into patches with
    `get_patches(..., mode="label")`, pair every label patch with the feature
    vector at the same (image, patch) position, reduce the features with PCA,
    drop patches containing the label -1, keep patches where one class covers
    more than 80% of the patch area, then run the distance / DBSCAN / metric
    helpers on the four class buckets.

    Relies on the notebook globals `dir`, `device`, `vit_patch`, `vit_region`,
    `patch_size`, `region_size` and `n_classes`.

    Args:
        expert: expert identifier forwarded to `get_Gleason`.
        features: indexable as `features[i][j]` -> feature vector of patch `j`
            of image `i`. NOTE(review): this assumes the expert's label maps
            are in the same order as the images the features were computed
            from - confirm against `get_Gleason`.

    Returns:
        A 13-tuple:
        ([len(cl_1), len(cl_3), len(cl_4), len(cl_5)], n_centroids, noisy, sc,
        min_dist, outliers, dists, means, stds, weights, percentage, eps,
        min_samples).
    """
    #### DATA LOADER ####
    train_data = get_Gleason(dir,expert=expert,mode="label")

    #### LABEL PATCHES ####
    patch_labels = []
    for label_map in train_data[0]:
        label_patches = get_patches(torch.from_numpy(label_map).to(device),vit_patch,vit_region,patch_size,region_size,mode="label")
        patch_labels.append(label_patches.cpu().detach().numpy())

    # Flatten (image, patch) pairs; the "<image>_<patch>" string is the key
    # later handed to the distance / DBSCAN helpers.
    features_aux = []
    labels_aux = []
    idx_aux = []
    for i,img_patches in enumerate(patch_labels):
        feat = features[i]
        for j,patch in enumerate(img_patches):
            features_aux.append(feat[j])
            idx_aux.append(f'{i}_{j}')
            labels_aux.append(np.unique(patch))

    #### PCA ####
    # A float n_components keeps enough components to explain 90% of variance.
    pca = PCA(n_components=0.9)
    principalComponents = pca.fit_transform(features_aux)

    # Keep only patches without the label -1 (the single-label and multi-label
    # cases reduce to the same membership test).
    knn_features = []
    knn_labels = []
    knn_idx = []
    for k,lbl in enumerate(labels_aux):
        if -1 in lbl:
            continue
        knn_features.append(principalComponents[k])
        knn_idx.append(idx_aux[k])
        knn_labels.append(lbl)

    ###### DIVIDE BY CLASSES #####
    cl_1 = {}
    cl_3 = {}
    cl_4 = {}
    cl_5 = {}
    # Raw label -> bucket. Labels 1 and 2 share the first bucket.
    buckets = ((1,cl_1),(2,cl_1),(3,cl_3),(4,cl_4),(5,cl_5))
    # Pixel count of one patch; derived from the configured patch size rather
    # than a hard-coded 256**2 so the 80% test follows the configuration.
    patch_area = patch_size**2
    min_fraction = 0.8

    for k,idx in enumerate(knn_idx):
        img_pos,patch_pos = idx.split('_')
        expert_ann = patch_labels[int(img_pos)][int(patch_pos)]
        cl,count = np.unique(expert_ann,return_counts=True)
        cl_counts = dict(zip(cl,count))

        for raw_label,bucket in buckets:
            if raw_label in knn_labels[k] and cl_counts[raw_label]/patch_area>min_fraction:
                bucket[idx] = knn_features[k]

    ###### DISTANCES #######
    dist1 = calculate_distances(cl_1)
    dist3 = calculate_distances(cl_3)
    dist4 = calculate_distances(cl_4)
    dist5 = calculate_distances(cl_5)

    ######## DBSCAN #########
    # Distinct name for the clustering output so it no longer rebinds the
    # label-patch list built above.
    centroids,cluster_labels,eps,min_samples = get_core_expert([cl_1,cl_3,cl_4,cl_5],[dist1,dist3,dist4,dist5])
    # The 2nd and 3rd metrics are computed by the helper but not returned here.
    sc,_sc2,_eucl_distances,min_dist,outliers,dists,means,stds,_ =  get_metrics(centroids,cluster_labels,return_neighbors=False)

    noisy = [get_noisy(label) for label in cluster_labels]
    n_centroids = [centroid.shape[0] for centroid in centroids]

    weights,percentage = get_weights(sc,outliers,dists,n_classes=n_classes)
    # np.save(f'{dir}/weights/weights{expert}_original.npy',weights)

    return [len(cl_1),len(cl_3),len(cl_4),len(cl_5)],n_centroids,noisy,sc,min_dist,\
        outliers,dists,means,stds,weights,percentage,eps,min_samples
        

In [None]:
# get_sc_expert returns 13 values; the previous 15-name target list (with the
# extra sc2_exp1 / eucl_distance_exp1) raised
# "ValueError: not enough values to unpack".
(len_cl1, n_centroids1, noisy1, sc_exp1, min_edist_exp1, outliers_exp1, edists_exp1,
 means_exp1, stds_exp1,
 weights_exp1, percentage_exp1, eps1, min_samples1) = get_sc_expert(1, features)

In [None]:
# get_sc_expert returns 13 values, so every target list has 13 names (the
# expert 1 list previously had 15 and raised ValueError on unpacking).
(len_cl1, n_centroids1, noisy1, sc_exp1, min_edist_exp1, outliers_exp1, edists_exp1,
 means_exp1, stds_exp1,
 weights_exp1, percentage_exp1, eps1, min_samples1) = get_sc_expert(1, features)

(len_cl2, n_centroids2, noisy2, sc_exp2, min_edist_exp2, outliers_exp2, edists_exp2,
 means_exp2, stds_exp2,
 weights_exp2, percentage_exp2, eps2, min_samples2) = get_sc_expert(2, features)


In [9]:
# get_sc_expert returns 13 values; the extra sc2_exp3 / eucl_distance_exp3
# targets made this unpacking raise ValueError.
(len_cl3, n_centroids3, noisy3, sc_exp3, min_edist_exp3, outliers_exp3, edists_exp3,
 means_exp3, stds_exp3,
 weights_exp3, percentage_exp3, eps3, min_samples3) = get_sc_expert(3, features)

  return self._call_impl(*args, **kwargs)


In [None]:
# get_sc_expert returns 13 values; each target list below therefore has 13
# names (the former sc2_exp* / eucl_distance_exp* targets raised ValueError).
(len_cl4, n_centroids4, noisy4, sc_exp4, min_edist_exp4, outliers_exp4, edists_exp4,
 means_exp4, stds_exp4,
 weights_exp4, percentage_exp4, eps4, min_samples4) = get_sc_expert(4, features)

(len_cl5, n_centroids5, noisy5, sc_exp5, min_edist_exp5, outliers_exp5, edists_exp5,
 means_exp5, stds_exp5,
 weights_exp5, percentage_exp5, eps5, min_samples5) = get_sc_expert(5, features)

(len_cl6, n_centroids6, noisy6, sc_exp6, min_edist_exp6, outliers_exp6, edists_exp6,
 means_exp6, stds_exp6,
 weights_exp6, percentage_exp6, eps6, min_samples6) = get_sc_expert(6, features)