In [1]:
%reload_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from source.datasets import get_BCSS,CustomDataset
from sklearn.decomposition import PCA
from source.vision_transformer import vit_small, vit4k_xs
from source.utils import update_state_dict
from utils.dbscan_utils import get_core_expert,get_metrics, get_weights,get_noisy
from utils.distance_utils import calculate_distances
from utils.patching_utils import get_patches


In [None]:
# Pick the compute device. Fall back to the CPU so that the later cells, which
# all reference `device`, still run on a machine without CUDA instead of
# failing with a NameError.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    cudnn.deterministic = True
    # Benchmark autotuning can select different kernels from run to run, which
    # works against the deterministic setting above, so it is turned off.
    cudnn.benchmark = False
print(device)

# NOTE: `dir` shadows the Python built-in of the same name. The name is kept
# because get_sc_expert reads this global directly.
dir = '/datadisk/datasets/breast/expert'
experts = ['expert','NP20','NP6','NP8','NP18','NP16','NP4']

img = ''
patch_size = 256        # side length passed to get_patches
region_size = 4096      # forwarded to get_patches when patching label maps
mini_patch_size = 16    # `patch_size` argument of the ViT (16x16 patch embedding)
n_classes = 4 #1,2,3,4
json_file = 'kfolds.json'
checkpoint_256 = '/checkpoints/vit256_small_dino.pth'

# One image per batch, in a fixed order, loaded in the main process.
imgs = get_BCSS(dir)
imgs_dataset = CustomDataset(imgs, False)

imgs_loader = torch.utils.data.DataLoader(imgs_dataset, batch_size=1, shuffle=False,num_workers=0)

cuda


In [None]:
#### SET MODELS ####
# Build the patch-level ViT (vit_small) that embeds each patch_size x patch_size tile.
vit_patch = vit_small(
    img_size=patch_size,
    patch_size=mini_patch_size,
    embed_dim=384,
    mask_attn=False,
    num_register_tokens=0,
)

# NOTE(review): weights_only=False unpickles arbitrary Python objects, so this
# must only ever be pointed at a trusted checkpoint file.
state_dict = torch.load(checkpoint_256, map_location="cpu", weights_only=False)

# Use the sub-dictionary stored under `checkpoint_key` when the file has one.
checkpoint_key = "teacher"
if checkpoint_key is not None and checkpoint_key in state_dict:
    state_dict = state_dict[checkpoint_key]

# Drop the `module.` text first and the `backbone.` text (multicrop wrapper)
# second. str.replace removes the text wherever it occurs in a key.
for prefix in ("module.", "backbone."):
    state_dict = {key.replace(prefix, ""): value for key, value in state_dict.items()}

# Reconcile the checkpoint with the model's own state dict, then load it.
# `msg` is returned by update_state_dict but not inspected here, and
# strict=False means mismatching keys do not raise.
state_dict, msg = update_state_dict(vit_patch.state_dict(), state_dict)
vit_patch.load_state_dict(state_dict, strict=False)

# Freeze every parameter: the network is used for inference only.
for param in vit_patch.parameters():
    param.requires_grad = False

# Move to the compute device and switch to eval mode; the returned module is
# the cell's displayed value.
vit_patch.to(device).eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Identity()
)

In [7]:
#### GET FEATURES ####
# Embed every image with the frozen patch ViT and collect the embeddings as
# NumPy arrays on the host.
features = []
for _, img in imgs_loader:
    img = img.to(device)
    feat = get_patches(img, vit_patch, patch_size, mode='features', region=False)
    # One list entry is added per element along the first axis of `feat`.
    # NOTE(review): get_sc_expert indexes `features[i]` by image, which presumably
    # relies on that first axis being the batch axis (batch_size=1) - confirm.
    features += list(feat.detach().cpu().numpy())

In [None]:
def get_sc_expert(expert,features):
    """Cluster one annotator's confidently-labelled patches and derive class weights.

    Steps:
      1. Load the label maps of ``expert`` and cut them into patches.
      2. Keep the patches that are not entirely "ignore" (value 5 is remapped
         to 10 in place; patches containing only 10 are dropped).
      3. Reduce the matching patch features with PCA (90% explained variance).
      4. Assign a patch to class 1-4 when that class covers at least 85% of
         its pixels.
      5. Compute per-class distances, run the DBSCAN helper, collect the
         metrics and weights, and save the weights to
         ``{dir}/weights/weights_{expert}_v2.npy``.

    Relies on the notebook globals ``dir``, ``device``, ``vit_patch``,
    ``patch_size``, ``region_size`` and ``n_classes``.

    Args:
        expert: Annotator identifier forwarded to ``get_BCSS`` and used in the
            name of the saved weights file.
        features: Per-image patch features; ``features[i][j]`` must be the
            feature of patch ``j`` of image ``i``.
            NOTE(review): this assumes the image order of ``features`` matches
            the order of ``get_BCSS(dir, expert, mode="label")[0]`` - confirm.

    Returns:
        A 15-tuple, in this order:
        ``[len(cl_1), len(cl_2), len(cl_3), len(cl_4)]`` (patches kept per
        class), ``n_centroids`` (``shape[0]`` of each centroid array),
        ``noisy`` (``get_noisy`` per class, ``0`` where it failed), then
        ``sc, min_dist, outliers, dists, means, stds, sc2, eucl_distances``
        as returned by ``get_metrics``, ``weights, percentage`` as returned by
        ``get_weights``, and ``eps, min_samples`` as returned by
        ``get_core_expert``.

    Raises:
        ValueError: If no patch with usable annotations is found.
    """
    import os  # local import: only needed to make sure the output folder exists

    train_data = get_BCSS(dir,expert,mode="label")

    #### GET LABELS ####
    labels = []
    for label in train_data[0]:
        label = get_patches(torch.from_numpy(label).to(device),vit_patch,patch_size,region_size,mode="label",region=False)
        labels.append(label.cpu().detach().numpy())

    #### KEEP ANNOTATED PATCHES ####
    features_aux = []
    idx_aux = []
    for i,lbl in enumerate(labels):
        feat = features[i]
        for j,label in enumerate(lbl):
            # In-place remap, so `labels[i][j]` read further down sees it too.
            label[label==5]=10
            present = np.unique(label)
            if len(present)==1 and present[0]==10:
                continue
            features_aux.append(feat[j])
            idx_aux.append(f'{i}_{j}')

    if not features_aux:
        # PCA would fail on empty input anyway; fail early with a clear message.
        raise ValueError(f"no annotated patches found for expert '{expert}'")

    #### PCA ####
    # A float n_components keeps as many components as needed to explain that
    # fraction of the variance.
    pca = PCA(n_components=0.9)
    principalComponents = pca.fit_transform(features_aux)

    ###### DIVIDE BY CLASSES #####
    # Minimum share of a patch's pixels a class must cover for the patch to
    # count as an example of that class. Class 0 is not clustered, so it is
    # not collected.
    # The comparison is `>=` for every class. With 256x256 patches a pixel
    # fraction can never equal 0.85 exactly, so this matches the strict `>`
    # previously used for class 2.
    min_fraction = {1: 0.85, 2: 0.85, 3: 0.85, 4: 0.85}
    patch_area = patch_size**2
    members_by_class = {cls: {} for cls in min_fraction}

    for i, idx in enumerate(idx_aux):
        img_idx, patch_idx = (int(part) for part in idx.split('_'))
        expert_ann = labels[img_idx][patch_idx]
        cl,count = np.unique(expert_ann,return_counts=True)
        cl_counts = dict(zip(cl,count))
        for cls, threshold in min_fraction.items():
            if cls in cl_counts and cl_counts[cls]/patch_area>=threshold:
                members_by_class[cls][idx] = principalComponents[i]

    class_members = [members_by_class[cls] for cls in (1, 2, 3, 4)]

    ###### DISTANCES #######
    # Computed once, after every patch has been assigned. Recomputing inside
    # the assignment loop only repeated the work on partially filled classes.
    class_dists = [calculate_distances(members) for members in class_members]

    ######## DBSCAN #########
    centroids,cluster_labels,eps,min_samples = get_core_expert(class_members,class_dists)
    sc,sc2,eucl_distances,min_dist,outliers,dists,means,stds,_ =  get_metrics(centroids,cluster_labels,return_neighbors=False)

    noisy = []
    for class_labels in cluster_labels:
        try:
            noisy.append(get_noisy(class_labels))
        except Exception:
            # Best effort: keep one entry per class so the list stays aligned
            # with the other per-class outputs.
            noisy.append(0)
    n_centroids = [centroid.shape[0] for centroid in centroids]

    weights,percentage = get_weights(sc,outliers,dists, n_classes=n_classes)
    weights_path = f'{dir}/weights/weights_{expert}_v2.npy'
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)
    np.save(weights_path,weights)

    return [len(members) for members in class_members],n_centroids,noisy,sc,min_dist,\
        outliers,dists,means,stds,sc2,eucl_distances,weights,percentage,eps,min_samples

In [None]:
# Run the clustering analysis on the 'expert' annotations; results are kept
# under the *_consensus names.
(
    len_cl_consensus, n_centroids_consensus, noisy_consensus, sc_consensus,
    min_edist_consensus, outliers_consensus, edists_consensus,
    means_consensus, stds_consensus, sc2_consensus, eucl_distance_consensus,
    weights_consensus, percentage_consensus, eps_consensus, min_samples_consensus,
) = get_sc_expert('expert', features)

In [None]:
# Same analysis for experts[1], experts[2] and experts[3].
# NOTE(review): experts[1] is 'NP20' in the configuration cell, yet its results
# are stored under *_NP2 names - confirm which one is intended.
(
    len_cl_NP2, n_centroids_NP2, noisy_NP2, sc_NP2,
    min_edist_NP2, outliers_NP2, edists_NP2,
    means_NP2, stds_NP2, sc2_NP2, eucl_distance_NP2,
    weights_NP2, percentage_NP2, eps_NP2, min_samples_NP2,
) = get_sc_expert(experts[1], features)

(
    len_cl_NP6, n_centroids_NP6, noisy_NP6, sc_NP6,
    min_edist_NP6, outliers_NP6, edists_NP6,
    means_NP6, stds_NP6, sc2_NP6, eucl_distance_NP6,
    weights_NP6, percentage_NP6, eps_NP6, min_samples_NP6,
) = get_sc_expert(experts[2], features)

(
    len_cl_NP8, n_centroids_NP8, noisy_NP8, sc_NP8,
    min_edist_NP8, outliers_NP8, edists_NP8,
    means_NP8, stds_NP8, sc2_NP8, eucl_distance_NP8,
    weights_NP8, percentage_NP8, eps_NP8, min_samples_NP8,
) = get_sc_expert(experts[3], features)

In [None]:
# Same analysis for experts[4], experts[5] and experts[6].
# NOTE(review): the configuration cell lists these as 'NP18', 'NP16' and 'NP4',
# yet the results are stored under *_NP14, *_NP18 and *_NP20 names - confirm
# which one is intended.
(
    len_cl_NP14, n_centroids_NP14, noisy_NP14, sc_NP14,
    min_edist_NP14, outliers_NP14, edists_NP14,
    means_NP14, stds_NP14, sc2_NP14, eucl_distance_NP14,
    weights_NP14, percentage_NP14, eps_NP14, min_samples_NP14,
) = get_sc_expert(experts[4], features)

(
    len_cl_NP18, n_centroids_NP18, noisy_NP18, sc_NP18,
    min_edist_NP18, outliers_NP18, edists_NP18,
    means_NP18, stds_NP18, sc2_NP18, eucl_distance_NP18,
    weights_NP18, percentage_NP18, eps_NP18, min_samples_NP18,
) = get_sc_expert(experts[5], features)

(
    len_cl_NP20, n_centroids_NP20, noisy_NP20, sc_NP20,
    min_edist_NP20, outliers_NP20, edists_NP20,
    means_NP20, stds_NP20, sc2_NP20, eucl_distance_NP20,
    weights_NP20, percentage_NP20, eps_NP20, min_samples_NP20,
) = get_sc_expert(experts[6], features)