In [2]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from source.dataset_gleason import get_PANDA, GleasonDataset
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN
from source.vision_transformer import vit_small, vit4k_xs
from source.utils import update_state_dict
import random 
import copy 
from utils.dbscan_utils import get_core_expert,get_metrics,get_nn_noisy,get_weights,get_noisy
from utils.patching_utils import get_patches
from utils.distance_utils import calculate_distances
from utils.experts_utils import create_expert

from skimage.io import imsave
import write_results


if torch.cuda.is_available():
    device = torch.device('cuda')
    cudnn.deterministic = True
    cudnn.benchmark = True
    print(device)
    
random.seed(0)
dir = '/home/laura/Documents/dataset/PANDA/'
img = ''
patch_size = 256
region_size = 4096
mini_patch_size = 16
checkpoint_256 = 'checkpoints/vit_256_small_dino_fold_4.pt'
checkpoint_4k = 'checkpoints/vit_4096_xs_dino_fold_4.pt'



cuda


In [3]:
##### DATASET #####
train_data = get_PANDA(dir)
train_dataset = GleasonDataset(train_data, False)
loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=False,
    num_workers=0)


In [4]:
#### SET MODELS ####
vit_patch = vit_small(
    img_size=patch_size,
    patch_size=mini_patch_size,
    embed_dim=384,
    mask_attn=False,
    num_register_tokens=0,
)

vit_region = vit4k_xs(
    img_size=region_size,
    patch_size=patch_size,
    input_embed_dim=384,
    output_embed_dim=192,
    mask_attn=False
)

state_dict = torch.load(checkpoint_256, map_location="cpu")
checkpoint_key = "teacher"
if checkpoint_key is not None and checkpoint_key in state_dict:
    state_dict = state_dict[checkpoint_key]
# remove `module.` prefix
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
# remove `backbone.` prefix induced by multicrop wrapper
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict, msg = update_state_dict(vit_patch.state_dict(), state_dict)
vit_patch.load_state_dict(state_dict, strict=False)
for name, param in vit_patch.named_parameters():
    param.requires_grad = False
vit_patch.to(device)
vit_patch.eval()

state_dict = torch.load(checkpoint_4k, map_location="cpu")
if checkpoint_key is not None and checkpoint_key in state_dict:
    state_dict = state_dict[checkpoint_key]
# remove `module.` prefix
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
# remove `backbone.` prefix induced by multicrop wrapper
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict, msg = update_state_dict(
    vit_region.state_dict(), state_dict
)
vit_region.load_state_dict(state_dict, strict=False)
for name, param in vit_region.named_parameters():
    param.requires_grad = False
vit_region.to(device)
vit_region.eval()

  state_dict = torch.load(checkpoint_256, map_location="cpu")
  state_dict = torch.load(checkpoint_4k, map_location="cpu")


VisionTransformer4K(
  (phi): Sequential(
    (0): Linear(in_features=384, out_features=192, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps

In [5]:
#### GET FEATURES ####
features = []
labels = []
for _,img,label in loader:
        img = img.to(device)
        label = label.to(device)
        feat,label = get_patches(img,label,vit_patch,vit_region,patch_size,region_size)
        features.extend(feat.cpu().detach().numpy())
        labels.append(label.cpu().detach().numpy())



features_aux = []
labels_aux = []
idx_aux = []
for i,lbl in enumerate(labels):
    feat = features[i]
    for j,label in enumerate(lbl):
            features_aux.append(feat[j])
            idx_aux.append(f'{i}_{j}')
            labels_aux.append(np.unique(label))

#### PCA ####
pca = PCA(n_components=0.9)
principalComponents = pca.fit_transform(features_aux)
explained_variance = pca.explained_variance_ratio_
total_variance = sum(list(explained_variance))*100


knn_features = []
knn_labels = []
knn_idx = []
for i,lbl in enumerate(labels_aux):
    if (2 in lbl and 3 in lbl) or (2 in lbl and 5 in lbl) or (3 in lbl and 4 in lbl) or (3 in lbl and 5 in lbl) or (4 in lbl and 5 in lbl):
        pass
    else:
        knn_features.append(principalComponents[i])
        knn_idx.append(idx_aux[i])
        if len(lbl)==1:
            knn_labels.append(lbl[0])
        else:
            knn_labels.append(lbl[-1])


###### DIVIDE BY CLASSES #####
cl_0 = {}
cl_1 = {}
cl_2 = {}
cl_3 = {}
cl_4 = {}
cl_5 = {}
for i, idx in enumerate(knn_idx):
    div_index = idx.split('_')
    expert_ann = labels[int(div_index[0])][int(div_index[1])]
    cl,count = np.unique(expert_ann,return_counts=True)
    cl_counts = dict(zip(cl,count))
    if len(cl)==1 and cl[0]==0:
        cl_0[idx] = knn_features[i]
    elif knn_labels[i] == 1 :
        if cl_counts[1]/(256**2)>0.15 and 0 not in cl:
            cl_1[idx] = knn_features[i]
    elif knn_labels[i] == 2:
        if cl_counts[2]/(256**2)>0.15 and 0 not in cl:
            cl_2[idx] = knn_features[i]
    elif knn_labels[i] == 3:
        if cl_counts[3]/(256**2)>0.15 and 0 not in cl:
            cl_3[idx] = knn_features[i] 
    elif knn_labels[i] == 4:
        if cl_counts[4]/(256**2)>0.15 and 0 not in cl:
            cl_4[idx] = knn_features[i] 
    elif knn_labels[i] == 5:
        if cl_counts[5]/(256**2)>0.15 and 0 not in cl:
            cl_5[idx] = knn_features[i] 
        

In [6]:
###### DISTANCES #######
dist0 = calculate_distances(cl_0)
dist1 = calculate_distances(cl_1)
dist2 = calculate_distances(cl_2)
dist3 = calculate_distances(cl_3)
dist4 = calculate_distances(cl_4)
dist5 = calculate_distances(cl_5)

In [7]:
##### ORIGINAL ######

centroids_orig,labels_orig,eps_orig,min_samples_orig = get_core_expert([cl_1,cl_2,cl_3,cl_4,cl_5],[dist1,dist2,dist3,dist4,dist5])
sc_orig,sc2_orig,eucl_distances_orig,min_dist_orig,outliers_orig,dists_orig,means_orig,stds_orig,neigbors_orig = get_metrics(centroids_orig,labels_orig)
noisy_orig = get_noisy(labels_orig)
nn_noisy_orig = get_nn_noisy([cl_1,cl_2,cl_3,cl_4,cl_5],centroids_orig,labels_orig)
weights_orig,percentage_orig = get_weights(sc_orig,outliers_orig,dists_orig)
# np.save('/home/laura/Documents/dataset/PANDA/expert_masks/weightsOrig_percentile.npy',weights_orig)

  return self._call_impl(*args, **kwargs)


In [8]:
###### EXPERT 1 ########
cls_exp1,new_dists_exp1 = create_expert(1,{2:[3],3:[2]},{2:cl_2,3:cl_3},labels,train_data[2],0.2,0)
centroids_exp1,labels_exp1,eps_exp1,min_samples_exp1 = get_core_expert([cl_1,cls_exp1[2],cls_exp1[3],cl_4,cl_5],[dist1,new_dists_exp1[2],new_dists_exp1[3],dist4,dist5])
sc_exp1,sc2_exp1,eucl_distances_exp1,min_dist_exp1,outliers_exp1,dists_exp1,means_exp1,stds_exp1,neigbors_exp1 = get_metrics(centroids_exp1,labels_exp1)
noisy_exp1 = get_noisy(labels_exp1)
nn_noisy_exp1 = get_nn_noisy([cl_1,cls_exp1[2],cls_exp1[3],cl_4,cl_5],centroids_exp1,labels_exp1)
weights_exp1,percentage_exp1 = get_weights(sc_exp1,outliers_exp1,dists_exp1)
# np.save('/home/laura/Documents/dataset/PANDA/expert_masks/weights1_percentile.npy',weights_exp1)

  return self._call_impl(*args, **kwargs)


In [9]:
###### EXPERT 2 ########
cls_exp2,new_dists_exp2 = create_expert(2,{4:[3],3:[4]},{3:cl_3,4:cl_4},labels,train_data[2],0.2,1)
centroids_exp2,labels_exp2,eps_exp2,min_samples_exp2 = get_core_expert([cl_1,cl_2,cls_exp2[3],cls_exp2[4],cl_5],[dist1,dist2,new_dists_exp2[3],new_dists_exp2[4],dist5])
sc_exp2,sc2_exp2,eucl_distances_exp2,min_dist_exp2,outliers_exp2,dists_exp2,means_exp2,stds_exp2,neigbors_exp2 = get_metrics(centroids_exp2,labels_exp2)
noisy_exp2 = get_noisy(labels_exp2)
nn_noisy_exp2 = get_nn_noisy([cl_1,cl_2,cls_exp2[3],cls_exp2[4],cl_5],centroids_exp2,labels_exp2)
weights_exp2,percentage_exp2 = get_weights(sc_exp2,outliers_exp2,dists_exp2)
# np.save('/home/laura/Documents/dataset/PANDA/expert_masks/weights2_percentile.npy',weights_exp2)

  return self._call_impl(*args, **kwargs)


In [1]:
###### EXPERT 3 ########
cls_exp3,new_dists_exp3 = create_expert(3,{2:[3],4:[3],3:[2,4]},{2:cl_2,3:cl_3,4:cl_4},labels,train_data[2],0.15,47)
centroids_exp3,labels_exp3,eps_exp3,min_samples_exp3 = get_core_expert([cl_1,cls_exp3[2],cls_exp3[3],cls_exp3[4],cl_5],[dist1,new_dists_exp3[2],new_dists_exp3[3],new_dists_exp3[4],dist5])
sc_exp3,sc2_exp3,eucl_distances_exp3,min_dist_exp3,outliers_exp3,dists_exp3,means_exp3,stds_exp3,neigbors_exp3 = get_metrics(centroids_exp3,labels_exp3)
noisy_exp3 = get_noisy(labels_exp3)
nn_noisy_exp3 = get_nn_noisy([cl_1,cls_exp3[2],cls_exp3[3],cls_exp3[4],cl_5],centroids_exp3,labels_exp3)
weights_exp3,percentage_exp3 = get_weights(sc_exp3,outliers_exp3,dists_exp3)
# np.save('/home/laura/Documents/dataset/PANDA/expert_masks/weights3_percentile.npy',weights_exp3)

NameError: name 'create_expert' is not defined

In [11]:
###### EXPERT 4 ########
cls_exp4,new_dists_exp4 = create_expert(4,{2:[3],3:[2]},{2:cl_2,3:cl_3},labels,train_data[2],0.25,10)
centroids_exp4,labels_exp4,eps_exp4,min_samples_exp4 = get_core_expert([cl_1,cls_exp4[2],cls_exp4[3],cl_4,cl_5],[dist1,new_dists_exp4[2],new_dists_exp4[3],dist4,dist5])
sc_exp4,sc2_exp4,eucl_distances_exp4,min_dist_exp4,outliers_exp4,dists_exp4,means_exp4,stds_exp4,neigbors_exp4 = get_metrics(centroids_exp4,labels_exp4)
nn_noisy_exp4 = get_nn_noisy([cl_1,cls_exp4[2],cls_exp4[3],cl_4,cl_5],centroids_exp4,labels_exp4)
noisy_exp4 = get_noisy(labels_exp4)
weights_exp4,percentage_exp4 = get_weights(sc_exp4,outliers_exp4,dists_exp4)
# np.save('/home/laura/Documents/dataset/PANDA/expert_masks/weights4_percentile.npy',weights_exp4)

  return self._call_impl(*args, **kwargs)


In [12]:
###### EXPERT 5 ########
cls_exp5,new_dists_exp5 = create_expert(5,{4:[3],3:[4]},{3:cl_3,4:cl_4},labels,train_data[2],0.85,0)
centroids_exp5,labels_exp5,eps_exp5,min_samples_exp5 = get_core_expert([cl_1,cl_2,cls_exp5[3],cls_exp5[4],cl_5],[dist1,dist2,new_dists_exp5[3],new_dists_exp5[4],dist5])
sc_exp5,sc2_exp5,eucl_distances_exp5,min_dist_exp5,outliers_exp5,dists_exp5,means_exp5,stds_exp5,neigbors_exp5 = get_metrics(centroids_exp5,labels_exp5)
noisy_exp5 = get_noisy(labels_exp5)
nn_noisy_exp5 = get_nn_noisy([cl_1,cl_2,cls_exp5[3],cls_exp5[4],cl_5],centroids_exp5,labels_exp5)
weights_exp5,percentage_exp5 = get_weights(sc_exp5,outliers_exp5,dists_exp5)
# np.save('/home/laura/Documents/dataset/PANDA/expert_masks/weights5_percentile.npy',weights_exp5)

  return self._call_impl(*args, **kwargs)


In [13]:
###### EXPERT 6 ########
cls_exp6,new_dists_exp6 = create_expert(6,{2:[3],4:[3],3:[2,4]},{2:cl_2,3:cl_3,4:cl_4},labels,train_data[2],0.1,5)
centroids_exp6,labels_exp6,eps_exp6,min_samples_exp6 = get_core_expert([cl_1,cls_exp6[2],cls_exp6[3],cls_exp6[4],cl_5],[dist1,new_dists_exp6[2],new_dists_exp6[3],new_dists_exp6[4],dist5])
sc_exp6,sc2_exp6,eucl_distances_exp6,min_dist_exp6,outliers_exp6,dists_exp6,means_exp6,stds_exp6,neigbors_exp6 = get_metrics(centroids_exp6,labels_exp6)
noisy_exp6 = get_noisy(labels_exp6)
nn_noisy_exp6 = get_nn_noisy([cl_1,cls_exp6[2],cls_exp6[3],cls_exp6[4],cl_5],centroids_exp6,labels_exp6)
weights_exp6,percentage_exp6 = get_weights(sc_exp6,outliers_exp6,dists_exp6)
# np.save('/home/laura/Documents/dataset/PANDA/expert_masks/weights6_percentile.npy',weights_exp6)

  return self._call_impl(*args, **kwargs)


In [14]:
eps = [eps_orig,eps_exp1,eps_exp2,eps_exp3,eps_exp4,eps_exp5,eps_exp6]
min_samples = [min_samples_orig,min_samples_exp1,min_samples_exp2,min_samples_exp3,min_samples_exp4,min_samples_exp5,min_samples_exp6]
centroids = [centroids_orig,centroids_exp1,centroids_exp2,centroids_exp3,centroids_exp4,centroids_exp5,centroids_exp6]
nn_noisy = [nn_noisy_orig,nn_noisy_exp1,nn_noisy_exp2,nn_noisy_exp3,nn_noisy_exp4,nn_noisy_exp5,nn_noisy_exp6]
noisy = [noisy_orig,noisy_exp1,noisy_exp2,noisy_exp3,noisy_exp4,noisy_exp5,noisy_exp6]
nn = [neigbors_orig,neigbors_exp1,neigbors_exp2,neigbors_exp3,neigbors_exp4,neigbors_exp5,neigbors_exp6]
outliers = [outliers_orig,outliers_exp1,outliers_exp2,outliers_exp3,outliers_exp4,outliers_exp5,outliers_exp6]
means = [means_orig,means_exp1,means_exp2,means_exp3,means_exp4,means_exp5,means_exp6]
stds = [stds_orig,stds_exp1,stds_exp2,stds_exp3,stds_exp4,stds_exp5,stds_exp6]
e_distance = [dists_orig,dists_exp1,dists_exp2,dists_exp3,dists_exp4,dists_exp5,dists_exp6]
dist_sc_v2 = [eucl_distances_orig,eucl_distances_exp1,eucl_distances_exp2,eucl_distances_exp3,eucl_distances_exp4,eucl_distances_exp5,eucl_distances_exp6]
weights = [weights_orig,weights_exp1,weights_exp2,weights_exp3,weights_exp4,weights_exp5,weights_exp6]
percentage = [percentage_orig,percentage_exp1,percentage_exp2,percentage_exp3,percentage_exp4,percentage_exp5,percentage_exp6]
sc_v2 = [sc2_orig,sc2_exp1,sc2_exp2,sc2_exp3,sc2_exp4,sc2_exp5,sc2_exp6]
sc_v1 = [sc_orig,sc_exp1,sc_exp2,sc_exp3,sc_exp4,sc_exp5,sc_exp6]
min_e_distance = [min_dist_orig,min_dist_exp1,min_dist_exp2,min_dist_exp3,min_dist_exp4,min_dist_exp5,min_dist_exp6]

cl_len_orig = [len(cl_1),len(cl_2),len(cl_3),len(cl_4),len(cl_5)]
cl_len_exp1 = [len(cl_1),len(cls_exp1[2]),len(cls_exp1[3]),len(cl_4),len(cl_5)]
cl_len_exp2 = [len(cl_1),len(cl_2),len(cls_exp2[3]),len(cls_exp2[4]),len(cl_5)]
cl_len_exp3 = [len(cl_1),len(cls_exp3[2]),len(cls_exp3[3]),len(cls_exp3[4]),len(cl_5)]
cl_len_exp4 = [len(cl_1),len(cls_exp4[2]),len(cls_exp4[3]),len(cl_4),len(cl_5)]
cl_len_exp5 = [len(cl_1),len(cl_2),len(cls_exp5[3]),len(cls_exp5[4]),len(cl_5)]
cl_len_exp6 = [len(cl_1),len(cls_exp6[2]),len(cls_exp6[3]),len(cls_exp6[4]),len(cl_5)]
cl_len = [cl_len_orig,cl_len_exp1,cl_len_exp2,cl_len_exp3,cl_len_exp4,cl_len_exp5,cl_len_exp6]

experts = ['original','exp1','exp2','exp3','exp4','exp5','exp6']


In [17]:
import write_results
cl = ['cl1','cl2','cl3','cl4','cl5']
write_results.write_cores(nn,centroids,cl_len,experts=experts)
write_results.write_noisy(nn_noisy,noisy,cl_len,experts=experts)
write_results.write_outliers(outliers,experts=experts)
write_results.write_euclidean_distances(means,stds,experts=experts)
write_results.write_edistances(e_distance,experts=experts)
write_results.write_min_euclidean_distance(dist_sc_v2,experts=experts)
write_results.write_min_edistance(min_e_distance,experts=experts)
write_results.write_sc1(sc_v1,experts=experts)
write_results.write_sc2(sc_v2,experts=experts)
write_results.write_weights(weights,percentage,experts=experts)
write_results.write_parameters(cl_len,eps,min_samples,experts=experts)