In [None]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from source.datasets import get_PANDA, CustomDataset
from sklearn.decomposition import PCA
from source.vision_transformer import vit_small, vit4k_xs
from source.utils import update_state_dict
import random 
from utils.dbscan_utils import get_core_expert,get_core_noncore_points
from utils.distance_utils import calculate_distances
from utils.patching_utils import get_patches
from SC_maps_utils import get_silhouette_complete,get_outliers

# Select the compute device. Falling back to the CPU keeps `device` defined on
# machines without CUDA, so the later `.to(device)` calls do not raise NameError.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    cudnn.deterministic = True
    # The cuDNN autotuner may pick a different kernel from run to run, which
    # defeats `deterministic = True`; keep it off for reproducible features.
    cudnn.benchmark = False
print(device)

# Seed every RNG source this notebook imports, not only Python's `random`.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset root handed to get_PANDA.
# NOTE(review): `dir` shadows the builtin; the name is kept because the dataset cell reads it.
dir = '/datadisk/datasets/PANDA/'
img = ''
# Annotation source forwarded to get_PANDA(expert=...).
expert = 'original'
# Sizes in pixels: the patch-level ViT takes `patch_size` images split into
# `mini_patch_size` tokens; the region-level ViT takes `region_size` images
# split into `patch_size` tokens (see the model-construction cell).
patch_size = 256
region_size = 4096
mini_patch_size = 16
# Checkpoints for the patch-level and region-level ViTs.
checkpoint_256 = '/checkpoints/vit_256_small_dino_fold_4.pt'
checkpoint_4k = '/checkpoints/vit_4096_xs_dino_fold_4.pt'

cuda


In [None]:
##### DATASET #####
# Build the PANDA records for the configured annotation source and wrap them
# in the project's Dataset class.
train_data = get_PANDA(dir, expert=expert)
train_dataset = CustomDataset(train_data)

# One item per batch, in dataset order, loaded in the main process.
loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)



In [5]:
#### SET MODELS ####
def _load_frozen_checkpoint(model, checkpoint_path, checkpoint_key="teacher"):
    """Load checkpoint weights into `model`, freeze it, and put it in eval mode.

    Args:
        model: module to populate; it is moved to the notebook-level `device`.
        checkpoint_path: path passed to `torch.load`.
        checkpoint_key: if not None and present in the loaded object, the
            weights are read from that sub-dictionary.

    Returns:
        The same `model` instance, frozen, on `device`, in eval mode.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects. Only use this with checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    if checkpoint_key is not None and checkpoint_key in state_dict:
        state_dict = state_dict[checkpoint_key]
    # Drop the `module.` wrapper prefix, then the `backbone.` prefix induced by
    # the multicrop wrapper (same two passes, in the same order, as before).
    for wrapper_prefix in ("module.", "backbone."):
        state_dict = {k.replace(wrapper_prefix, ""): v for k, v in state_dict.items()}
    state_dict, msg = update_state_dict(model.state_dict(), state_dict)
    # strict=False hides key mismatches, so surface the helper's report.
    # NOTE(review): presumably `msg` summarises which weights matched — confirm.
    print(f"{checkpoint_path}: {msg}")
    model.load_state_dict(state_dict, strict=False)
    # Inference only: no parameter should accumulate gradients.
    model.requires_grad_(False)
    model.to(device)
    model.eval()
    return model


# Patch-level ViT: `patch_size` inputs tokenised into `mini_patch_size` tokens.
vit_patch = vit_small(
    img_size=patch_size,
    patch_size=mini_patch_size,
    embed_dim=384,
    mask_attn=False,
    num_register_tokens=0,
)

# Region-level ViT: `region_size` inputs tokenised into `patch_size` tokens.
# Its input embedding width (384) matches the patch model's `embed_dim`.
vit_region = vit4k_xs(
    img_size=region_size,
    patch_size=patch_size,
    input_embed_dim=384,
    output_embed_dim=192,
    mask_attn=False
)

checkpoint_key = "teacher"
vit_patch = _load_frozen_checkpoint(vit_patch, checkpoint_256, checkpoint_key)
vit_region = _load_frozen_checkpoint(vit_region, checkpoint_4k, checkpoint_key)

# Last expression of the cell: display the region model's architecture.
vit_region

VisionTransformer4K(
  (phi): Sequential(
    (0): Linear(in_features=384, out_features=192, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps

In [None]:
#### GET FEATURES ####
# Run every loader item through get_patches and collect the returned features
# and labels as NumPy arrays. The loop variable is named `image` so it does not
# overwrite the `img` value set in the configuration cell.
features = []
labels = []
for _, image, batch_label in loader:
    image = image.to(device)
    batch_label = batch_label.to(device)
    feat, batch_label = get_patches(image, vit_patch, vit_region, patch_size, region_size,
                                    y=batch_label, mode='both', region=True)
    # NOTE(review): features are extended (one entry per leading-axis element)
    # while labels are appended (one entry per loader item). The flattening
    # below relies on features[i] lining up with labels[i] — confirm against
    # the shapes get_patches returns.
    features.extend(feat.cpu().detach().numpy())
    labels.append(batch_label.cpu().detach().numpy())

# Flatten to one row per annotated patch. Each row keeps an "i_j" id so the
# annotation can be looked up again as labels[i][j].
features_aux = []
labels_aux = []
idx_aux = []
for i, item_labels in enumerate(labels):
    item_feats = features[i]
    for j, patch_label in enumerate(item_labels):
        features_aux.append(item_feats[j])
        idx_aux.append(f'{i}_{j}')
        # Sorted set of classes present in this patch annotation.
        labels_aux.append(np.unique(patch_label))

#### PCA ####
# Keep as many components as needed to explain 90% of the variance.
pca = PCA(n_components=0.9)
principalComponents = pca.fit_transform(features_aux)
explained_variance = pca.explained_variance_ratio_
total_variance = sum(list(explained_variance))*100

# Patches containing any of these class pairs are dropped.
# NOTE(review): (2, 4) is not in this list although every other pair among
# classes 2-5 is — confirm that this is intentional.
EXCLUDED_CLASS_PAIRS = ((2, 3), (2, 5), (3, 4), (3, 5), (4, 5))

knn_features = []
knn_labels = []
knn_idx = []
for i, present in enumerate(labels_aux):
    if any(a in present and b in present for a, b in EXCLUDED_CLASS_PAIRS):
        continue
    knn_features.append(principalComponents[i])
    knn_idx.append(idx_aux[i])
    # np.unique returns sorted values, so the last one is the highest class
    # present (and the only one when the patch has a single class).
    knn_labels.append(present[-1])


###### DIVIDE BY CLASSES #####
# A patch is assigned to class 1-5 only if that class covers more than this
# fraction of the patch and the patch contains no class-0 pixels.
MIN_CLASS_FRACTION = 0.15
# Pixel count used as the denominator; follows the configured patch size.
# NOTE(review): assumes each annotation is patch_size x patch_size — confirm.
patch_area = patch_size ** 2
FOREGROUND_CLASSES = (1, 2, 3, 4, 5)

# class id -> {"i_j" patch id -> PCA-reduced feature vector}
class_features = {c: {} for c in (0,) + FOREGROUND_CLASSES}
for i, idx in enumerate(knn_idx):
    div_index = idx.split('_')
    expert_ann = labels[int(div_index[0])][int(div_index[1])]
    cl, count = np.unique(expert_ann, return_counts=True)
    cl_counts = dict(zip(cl, count))
    top_class = knn_labels[i]
    if len(cl) == 1 and cl[0] == 0:
        # Patch made up of class 0 only.
        class_features[0][idx] = knn_features[i]
    elif top_class in FOREGROUND_CLASSES:
        if cl_counts[top_class] / patch_area > MIN_CLASS_FRACTION and 0 not in cl:
            class_features[int(top_class)][idx] = knn_features[i]

# Keep the per-class names that the following cells read.
cl_0, cl_1, cl_2, cl_3, cl_4, cl_5 = (class_features[c] for c in range(6))
        

In [7]:
###### DISTANCES #######
# Apply calculate_distances to each per-class feature dictionary, in class
# order 0..5, and bind the results to dist0..dist5.
dist0, dist1, dist2, dist3, dist4, dist5 = [
    calculate_distances(per_class)
    for per_class in (cl_0, cl_1, cl_2, cl_3, cl_4, cl_5)
]

In [8]:
# Run get_core_expert on classes 1-5 (class 0 is left out), passing the
# per-class features and their distances in matching order.
centroids_orig, labels_orig, eps_orig, min_samples_orig = get_core_expert(
    [cl_1, cl_2, cl_3, cl_4, cl_5],
    [dist1, dist2, dist3, dist4, dist5],
)

# labels_orig[k] belongs to class k + 1; compute the outliers of each class.
outliers1, outliers2, outliers3, outliers4, outliers5 = [
    get_outliers(labels_orig[k], per_class)
    for k, per_class in enumerate((cl_1, cl_2, cl_3, cl_4, cl_5))
]

In [None]:
# Split each class (1-5) into core / non-core points using the centroids and
# labels stored at position k for class k + 1.
(
    (core1, noncore1),
    (core2, noncore2),
    (core3, noncore3),
    (core4, noncore4),
    (core5, noncore5),
) = [
    get_core_noncore_points(per_class, centroids_orig[k], labels_orig[k])
    for k, per_class in enumerate((cl_1, cl_2, cl_3, cl_4, cl_5))
]

# Silhouette computed over the core points of all five classes.
cores = [core1, core2, core3, core4, core5]
s_core = get_silhouette_complete(cores)

In [23]:
# Hard-coded silhouette values, one per class (1-5).
# NOTE(review): presumably copied from an earlier run of the pipeline — confirm the source.
sil1 = [
    0.1948883852802351,
    0.04331368410891981,
    0.09613475633285883,
    0.2553953597061623,
    0.1333708716583793,
]

# Outlier fraction per class, written as count / total.
# NOTE(review): these rebind outliers1..outliers5, replacing the get_outliers
# results computed in an earlier cell.
outliers1, outliers2, outliers3, outliers4, outliers5 = (
    752 / 7063,
    28 / 706,
    92 / 151,
    216 / 1309,
    8 / 42,
)
outliers = [outliers1, outliers2, outliers3, outliers4, outliers5]

# Scale each silhouette by the fraction of points that are not outliers.
sc1 = [score * (1 - fraction) for score, fraction in zip(sil1, outliers)]
sc1

[0.17413855295250796,
 0.04159586094312696,
 0.03756258691151437,
 0.2132521987462455,
 0.10796689610440229]

In [26]:
# Hard-coded silhouette values, one per class (1-5).
# NOTE(review): presumably copied from an earlier run of the pipeline — confirm the source.
sil2 = [
    0.19542892965896444,
    0.09727598751633999,
    0.017125044101363018,
    0.16556768169160382,
    0.14633140057743327,
]

# Outlier fraction per class, written as count / total.
outliers1, outliers2, outliers3, outliers4, outliers5 = (
    752 / 7063,
    44 / 792,
    42 / 552,
    215 / 1120,
    8 / 42,
)
outliers = [outliers1, outliers2, outliers3, outliers4, outliers5]

# Scale each silhouette by the fraction of points that are not outliers.
sc2 = [score * (1 - fraction) for score, fraction in zip(sil2, outliers)]
sc2

[0.1746215453883229,
 0.09187176598765444,
 0.015822051615389743,
 0.13378459993830488,
 0.11845875284839837]

In [27]:
# Hard-coded silhouette values, one per class (1-5).
# NOTE(review): presumably copied from an earlier run of the pipeline — confirm the source.
sil3 = [
    0.19418097380697355,
    0.04553016541330985,
    0.06461245162781215,
    0.20946730372073477,
    0.14155606143216817,
]

# Outlier fraction per class, written as count / total.
outliers1, outliers2, outliers3, outliers4, outliers5 = (
    752 / 7063,
    20 / 728,
    92 / 569,
    169 / 1167,
    8 / 42,
)
outliers = [outliers1, outliers2, outliers3, outliers4, outliers5]

# Scale each silhouette by the fraction of points that are not outliers.
sc3 = [score * (1 - fraction) for score, fraction in zip(sil3, outliers)]
sc3

[0.17350645981818066,
 0.04427933669316397,
 0.054165447146689624,
 0.1791331354869694,
 0.11459300211175519]

In [28]:
# Hard-coded silhouette values, one per class (1-5).
# NOTE(review): presumably copied from an earlier run of the pipeline — confirm the source.
silOrig = [
    0.19264332758115896,
    0.08851190871539659,
    0.13249770354336182,
    0.2543056422924208,
    0.13478291233762052,
]

# Outlier fraction per class, written as count / total.
outliers1, outliers2, outliers3, outliers4, outliers5 = (
    752 / 7063,
    44 / 792,
    64 / 363,
    216 / 1309,
    8 / 42,
)
outliers = [outliers1, outliers2, outliers3, outliers4, outliers5]

# Scale each silhouette by the fraction of points that are not outliers.
scOrig = [score * (1 - fraction) for score, fraction in zip(silOrig, outliers)]
scOrig

[0.17213252730634207,
 0.08359458045343011,
 0.10913722688557902,
 0.21234229719298392,
 0.10910997665426424]