In [None]:
import torch
import torch.backends.cudnn as cudnn
from source.datasets import get_PANDA, CustomDataset
from source.vision_transformer import vit_small, vit4k_xs
from source.utils import update_state_dict
import random
from sklearn.decomposition import PCA
import copy 
import numpy as np
from utils.dbscan_utils import get_metrics
from utils.patching_utils import get_patches
from utils.experts_utils import create_expert


# Select the compute device. Falling back to the CPU keeps `device` defined on
# machines without CUDA; later cells call `.to(device)` unconditionally.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    cudnn.deterministic = True
    # cuDNN auto-tuning may select a different convolution algorithm from one
    # run to the next, which works against the deterministic flag above.
    cudnn.benchmark = False
print(device)

# Seed every RNG library this notebook imports, not only Python's `random`.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Root folder handed to get_PANDA().
# NOTE(review): absolute, machine-specific path, and the name shadows the
# builtin `dir`; it is kept because the dataset cell reads it under this name.
dir = '/home/laura/Documents/dataset/PANDA/'
img = ''
patch_size = 256       # img_size of vit_small and patch_size of vit4k_xs
region_size = 4096     # img_size of vit4k_xs
mini_patch_size = 16   # patch_size of vit_small
n_classes = 5          # forwarded to get_weights()
checkpoint_256 = 'checkpoints/vit_256_small_dino_fold_4.pt'
checkpoint_4k = 'checkpoints/vit_4096_xs_dino_fold_4.pt'



cuda


In [None]:
##### DATASET #####
# Build the PANDA dataset from the configured root folder and wrap it in a
# DataLoader that yields one sample per step, in dataset order, loading in the
# main process (no worker subprocesses).
train_data = get_PANDA(dir)
train_dataset = CustomDataset(train_data)
loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)


In [3]:
#### SET MODELS ####
def _load_frozen_weights(model, checkpoint_path, checkpoint_key="teacher"):
    """Load a checkpoint into `model`, freeze it, and put it in eval mode on `device`.

    Args:
        model: the network whose weights are replaced.
        checkpoint_path: file passed to ``torch.load`` (mapped to the CPU first).
        checkpoint_key: if not None and present in the loaded object, the
            sub-dictionary stored under this key is used as the state dict.

    Returns:
        The same `model` object, with every parameter's ``requires_grad`` set
        to False, moved to the notebook-level `device`, in eval mode.
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        state_dict = state_dict[checkpoint_key]
    # Drop the `module.` prefix, then the `backbone.` prefix induced by the
    # multicrop wrapper.
    state_dict = {
        key.replace("module.", "").replace("backbone.", ""): value
        for key, value in state_dict.items()
    }
    # update_state_dict is a project helper; it receives the model's current
    # state dict plus the checkpoint's and returns (state_dict, report).
    state_dict, msg = update_state_dict(model.state_dict(), state_dict)
    model.load_state_dict(state_dict, strict=False)
    # strict=False does not raise on key mismatches, so surface the helper's
    # report instead of discarding it.
    print(f"{checkpoint_path}: {msg}")
    for param in model.parameters():
        param.requires_grad = False
    model.to(device)
    model.eval()
    return model


vit_patch = vit_small(
    img_size=patch_size,
    patch_size=mini_patch_size,
    embed_dim=384,
    mask_attn=False,
    num_register_tokens=0,
)

vit_region = vit4k_xs(
    img_size=region_size,
    patch_size=patch_size,
    input_embed_dim=384,
    output_embed_dim=192,
    mask_attn=False
)

checkpoint_key = "teacher"
vit_patch = _load_frozen_weights(vit_patch, checkpoint_256, checkpoint_key)
vit_region = _load_frozen_weights(vit_region, checkpoint_4k, checkpoint_key)

# Last expression of the cell: display the region-level model summary.
vit_region

  state_dict = torch.load(checkpoint_256, map_location="cpu")
  state_dict = torch.load(checkpoint_4k, map_location="cpu")


VisionTransformer4K(
  (phi): Sequential(
    (0): Linear(in_features=384, out_features=192, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps

In [None]:
#### GET FEATURES ####
# Run every sample of the loader through get_patches() and collect what it
# returns as NumPy arrays on the host.
# NOTE(review): `features` grows with extend() (one entry per element of the
# first axis of `feat`) while `labels` grows with append() (one entry per
# sample); the two lists only stay index-aligned if that first axis has
# length 1, presumably the batch axis with batch_size=1. Confirm against
# get_patches().
features, labels = [], []
for _, img, label in loader:
    img, label = img.to(device), label.to(device)
    feat, label = get_patches(
        img,
        vit_patch,
        vit_region,
        patch_size,
        region_size,
        y=label,
        mode='both',
        region=True,
    )
    features.extend(feat.detach().cpu().numpy())
    labels.append(label.detach().cpu().numpy())



# Flatten the per-sample lists into one entry per patch. Each patch is
# addressed by (i, j): i indexes `labels`/`features`, j indexes the first axis
# of labels[i].
patch_index = [(i, j) for i, lbl in enumerate(labels) for j in range(len(lbl))]
features_aux = [features[i][j] for i, j in patch_index]
idx_aux = [f'{i}_{j}' for i, j in patch_index]
# Sorted set of label values present in each patch annotation.
labels_aux = [np.unique(labels[i][j]) for i, j in patch_index]

#### PCA ####
# A float n_components asks scikit-learn to keep as many components as are
# needed to explain that fraction of the variance.
pca = PCA(n_components=0.9)
principalComponents = pca.fit_transform(features_aux)
explained_variance = pca.explained_variance_ratio_
# Percentage of the variance captured by the retained components.
total_variance = 100 * sum(list(explained_variance))


# Keep only patches that contain at most one of the labels 2, 3, 4, 5, and
# give each kept patch a single label.
# The previous condition enumerated the pairs (2,3), (2,5), (3,4), (3,5) and
# (4,5) but not (2,4), so patches mixing labels 2 and 4 slipped through and
# were labelled 4. Testing the intersection size covers every pair.
# NOTE(review): presumably 2-5 are the epithelium / Gleason classes of the
# PANDA masks; confirm that excluding the 2+4 mix is the intended behaviour.
exclusive_labels = {2, 3, 4, 5}

knn_features = []
knn_labels = []
knn_idx = []
for i, lbl in enumerate(labels_aux):
    if len(exclusive_labels.intersection(lbl.tolist())) > 1:
        continue
    knn_features.append(principalComponents[i])
    knn_idx.append(idx_aux[i])
    # `lbl` comes from np.unique and is sorted, so the last element is the
    # highest label present (and the only one when len(lbl) == 1).
    knn_labels.append(lbl[-1])


###### DIVIDE BY CLASSES #####
# Route each kept patch into a per-class dict {patch id -> PCA feature}.
#   * class 0: the annotation contains only label 0;
#   * class k in 1..5: the patch label is k, label k covers more than
#     `min_class_fraction` of the patch area, and label 0 is absent.
# Patches meeting neither rule are dropped.
min_class_fraction = 0.15
# Was hard-coded as 256**2; tied to the configured patch size instead.
# NOTE(review): assumes each annotation patch is patch_size x patch_size.
patch_area = patch_size ** 2

patches_by_class = {class_id: {} for class_id in range(6)}
for i, idx in enumerate(knn_idx):
    # idx is '<i>_<j>': position in `labels`, then patch within that entry.
    outer, inner = (int(part) for part in idx.split('_'))
    expert_ann = labels[outer][inner]
    cl, count = np.unique(expert_ann, return_counts=True)
    cl_counts = dict(zip(cl, count))
    patch_label = knn_labels[i]
    if len(cl) == 1 and cl[0] == 0:
        patches_by_class[0][idx] = knn_features[i]
    elif patch_label in (1, 2, 3, 4, 5):
        covers_enough = cl_counts[patch_label] / patch_area > min_class_fraction
        if covers_enough and 0 not in cl:
            patches_by_class[patch_label][idx] = knn_features[i]

# Names used by the cells below.
cl_0, cl_1, cl_2, cl_3, cl_4, cl_5 = (
    patches_by_class[class_id] for class_id in range(6)
)
        

In [None]:
##### ORIGINAL ######
# Metrics and weights on the unmodified per-class patch sets (classes 1-5).
# NOTE(review): `get_weights` is not imported in the notebook cells shown;
# on a fresh kernel this raises NameError unless it is defined elsewhere.
(
    sc_orig,
    sc2_orig,
    eucl_distances_orig,
    min_dist_orig,
    dists_orig,
    means_orig,
    stds_orig,
) = get_metrics([cl_1, cl_2, cl_3, cl_4, cl_5])
weights_orig, percentage_orig = get_weights(
    sc_orig,
    [0, 0, 0, 0, 0],
    dists_orig,
    n_classes=n_classes,
    dbscan=False,
)

  return self._call_impl(*args, **kwargs)


In [None]:
###### EXPERT 1 ########
# Build expert 1 from the class-2 and class-3 patch sets (mapping 2 -> [3],
# 3 -> [2]), then recompute metrics and weights with those two sets replaced.
# NOTE(review): the meaning of create_expert's positional arguments (0.2, 0)
# is not visible here; confirm against utils.experts_utils.
# NOTE(review): `get_weights` is not imported in the notebook cells shown.
cls_exp1, new_dists_exp1 = create_expert(
    1,
    {2: [3], 3: [2]},
    {2: cl_2, 3: cl_3},
    labels,
    train_data[2],
    0.2,
    0,
)
(
    sc_exp1,
    sc2_exp1,
    eucl_distances_exp1,
    min_dist_exp1,
    dists_exp1,
    means_exp1,
    stds_exp1,
) = get_metrics([cl_1, cls_exp1[2], cls_exp1[3], cl_4, cl_5])
weights_exp1, percentage_exp1 = get_weights(
    sc_exp1,
    [0, 0, 0, 0, 0],
    dists_exp1,
    n_classes=n_classes,
    dbscan=False,
)

  return self._call_impl(*args, **kwargs)


In [None]:
###### EXPERT 2 ########
# Build expert 2 from the class-3 and class-4 patch sets (mapping 4 -> [3],
# 3 -> [4]), then recompute metrics and weights with those two sets replaced.
# NOTE(review): the meaning of create_expert's positional arguments (0.2, 1)
# is not visible here; confirm against utils.experts_utils.
# NOTE(review): `get_weights` is not imported in the notebook cells shown.
cls_exp2, new_dists_exp2 = create_expert(
    2,
    {4: [3], 3: [4]},
    {3: cl_3, 4: cl_4},
    labels,
    train_data[2],
    0.2,
    1,
)
(
    sc_exp2,
    sc2_exp2,
    eucl_distances_exp2,
    min_dist_exp2,
    dists_exp2,
    means_exp2,
    stds_exp2,
) = get_metrics([cl_1, cl_2, cls_exp2[3], cls_exp2[4], cl_5])
weights_exp2, percentage_exp2 = get_weights(
    sc_exp2,
    [0, 0, 0, 0, 0],
    dists_exp2,
    n_classes=n_classes,
    dbscan=False,
)

  return self._call_impl(*args, **kwargs)


In [None]:
###### EXPERT 3 ########
# Build expert 3 from the class-2, class-3 and class-4 patch sets (mapping
# 2 -> [3], 4 -> [3], 3 -> [2, 4]), then recompute metrics and weights with
# those three sets replaced.
# NOTE(review): the meaning of create_expert's positional arguments (0.15, 47)
# is not visible here; confirm against utils.experts_utils.
# NOTE(review): `get_weights` is not imported in the notebook cells shown.
cls_exp3, new_dists_exp3 = create_expert(
    3,
    {2: [3], 4: [3], 3: [2, 4]},
    {2: cl_2, 3: cl_3, 4: cl_4},
    labels,
    train_data[2],
    0.15,
    47,
)
(
    sc_exp3,
    sc2_exp3,
    eucl_distances_exp3,
    min_dist_exp3,
    dists_exp3,
    means_exp3,
    stds_exp3,
) = get_metrics([cl_1, cls_exp3[2], cls_exp3[3], cls_exp3[4], cl_5])
weights_exp3, percentage_exp3 = get_weights(
    sc_exp3,
    [0, 0, 0, 0, 0],
    dists_exp3,
    n_classes=n_classes,
    dbscan=False,
)

  return self._call_impl(*args, **kwargs)
