In [1]:
import numpy as np
import glob
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D 
import json
import os
import pytorch_lightning as pl
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import make_grid 
from torchvision import models
from PIL import Image
from pytorch_lightning.callbacks import early_stopping, model_checkpoint, ProgressBar
from pl_bolts.models.self_supervised import Moco_v2, BYOL
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
import tqdm
from sklearn.decomposition import PCA
import gc
from torchvision.io import ImageReadMode, read_image
import torchvision
import threading, queue
from concurrent.futures import ThreadPoolExecutor

%matplotlib inline

In [2]:
#We would like to visualize the latent space using different Encoders
gc.collect()
torch.cuda.empty_cache()

In [3]:
# first we have to create a dataset that loads slide by slide

  
    
class Camelyon16PreprocesseSlidedDataset(torch.utils.data.Dataset):
    """
    Dataset of unlabelled patches
    """
    
    def __init__(self, csv_file, transforms=None):
        self.data = pd.read_csv(csv_file)
        self.data["slide"] = self.data["path"].str.split("/", expand=True)[1]
        groups = self.data.groupby("slide")
        self.slidenames = np.unique(self.data["slide"])
        self.patchesByslide = list(groups.groups.values())
        self.transforms = transforms
        
    def __len__(self):
        return len(self.patchesByslide)

    def __getitem__(self, idx):
        patches = self.data.iloc[self.patchesByslide[idx]]

#         ids = np.arange(0, len(patches))
#         np.random.shuffle(ids)
#         patches = patches.iloc[ids]

        imgs = torch.stack([read_image(path, ImageReadMode.RGB) for path in patches["path"]])/255

        global_label = patches["global_class"].iloc[0]
        local_labels = patches["local_class"].to_numpy()
        
        imgs = self.transforms(imgs)
        
        return imgs, (global_label, local_labels)
    
    def getpatchesidx(self, idx):
        return self.data.iloc[self.patchesByslide[idx]], self.slidenames[idx], self.patchesByslide[idx]
    
        
    def getbatchitem(self, patches, ids, batch_idx=0, batch_size=256):

        if(len(patches)//batch_size < batch_idx):
            print(batch_idx)
            return None
        
        patches = patches.iloc[batch_idx*batch_size:(batch_idx+1)*batch_size]
        imgs = torch.stack([read_image(path, ImageReadMode.RGB) for path in patches["path"]])/255
        global_label = patches["global_class"].iloc[0]
        local_labels = patches["local_class"].to_numpy()
        
        imgs = self.transforms(imgs)

        return (imgs, (global_label, local_labels), ids[batch_idx*batch_size:(batch_idx+1)*batch_size])
    

def getResnet18():
    MODEL_PATH = 'models/encoder/resnet18.ckpt'
    RETURN_PREACTIVATION = True  # return features from the model, if false return classification logits
    NUM_CLASSES = 2  # only used if RETURN_PREACTIVATION = False


    def load_model_weights(model, weights):

        model_dict = model.state_dict()
        weights = {k: v for k, v in weights.items() if k in model_dict}
        if weights == {}:
            print('No weight could be loaded..')
        model_dict.update(weights)
        model.load_state_dict(model_dict)

        return model


    model = torchvision.models.__dict__['resnet18'](pretrained=False)

    state = torch.load(MODEL_PATH, map_location='cuda:0')

    state_dict = state['state_dict']
    for key in list(state_dict.keys()):
        state_dict[key.replace('model.', '').replace('resnet.', '')] = state_dict.pop(key)

    model = load_model_weights(model, state_dict)

    if RETURN_PREACTIVATION:
        model.fc = torch.nn.Sequential()
    else:
        model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)

    return model

In [4]:
data_transforms = transforms.Compose([imagenet_normalization()])
dataset = Camelyon16PreprocesseSlidedDataset("processed_data/data.csv", data_transforms)

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7efa9997d3a0>

In [6]:
np.unique(dataset.slidenames)

array(['normal_001', 'normal_002', 'normal_003', 'normal_004',
       'normal_005', 'normal_006', 'normal_007', 'normal_008',
       'normal_009', 'normal_010', 'normal_011', 'normal_012',
       'normal_013', 'normal_014', 'normal_015', 'normal_016',
       'normal_017', 'normal_018', 'normal_019', 'normal_020',
       'normal_021', 'normal_022', 'normal_023', 'normal_024',
       'normal_025', 'normal_026', 'normal_027', 'normal_028',
       'normal_029', 'normal_030', 'normal_031', 'normal_032',
       'normal_033', 'normal_034', 'normal_035', 'normal_036',
       'normal_037', 'normal_038', 'normal_039', 'normal_040',
       'normal_041', 'normal_042', 'normal_043', 'normal_044',
       'normal_045', 'normal_046', 'normal_047', 'normal_048',
       'normal_049', 'normal_050', 'normal_051', 'normal_052',
       'normal_053', 'normal_054', 'normal_055', 'normal_056',
       'normal_057', 'normal_058', 'normal_059', 'normal_060',
       'normal_061', 'normal_062', 'normal_063', 'norma

In [14]:
# model = getResnet18()
model = models.resnet18(True)
model = nn.Sequential(*list(model.children())[:-1]).cpu().eval()

In [None]:
# state_dict = torch.load("models/encoder/mocov3.pth")

In [None]:
# state_dict.keys()

In [None]:
# model = Moco_v2("resnet50")
# model = nn.Sequential(*list(model.encoder_q.children())[:-1]).cpu().eval()
# model.load_state_dict(state_dict)

In [15]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()
HOME = os.getcwd()

In [9]:
HOME

'/home/lab/WORK'

In [10]:
def encode(model, idx, device="cpu", save_encoding=None, batch_size=512, num_workers=4):
    
    patches, slidename, ids = dataset.getpatchesidx(idx)

    encodded = []
    labels = []
    
    
    available_batches = queue.Queue()
    prefetched = queue.Queue(maxsize=4)
    
    [available_batches.put(i) for i in range(len(patches)//batch_size)]
    
    def worker():
        while not available_batches.empty():
            item = available_batches.get()
            prefetched.put(dataset.getbatchitem(patches, ids, item, batch_size))
            available_batches.task_done() 

    with ThreadPoolExecutor() as executor:
        [executor.submit(worker) for i in range(num_workers)]
 
    
        for i in tqdm.tqdm(range(len(patches)//batch_size)):
            imgs, (global_label, local_labels), ids = prefetched.get()
            imgs = imgs.to(device)
            encodded.append(model(imgs).detach().cpu().numpy().reshape(-1, 2048))
            labels.append(local_labels)

        available_batches.join()
        
        encodded=np.concatenate(encodded,0)
        labels=np.concatenate(labels, 0)
        
        if save_encoding is not None:
            with open(os.path.join(save_encoding, f"{slidename}.npz"), "wb") as f:

                np.savez_compressed(f, imgs=encodded, labels=labels)

In [16]:
# state_dict = torch.load("models/encoder/mocov3.pth")
# model = Moco_v2("resnet50")
# model = nn.Sequential(*list(model.encoder_q.children())[:-1]).cpu().eval()
# model.load_state_dict(state_dict)

save_encoding=os.path.join(HOME, "encoded_data_imagenet")
batch_size=512
num_workers=8

In [None]:
for idx in range(len(dataset)):
    encode(model, idx, device=device, save_encoding=save_encoding, batch_size=batch_size, num_workers=num_workers)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|███████████████████████████████████████████| 18/18 [00:04<00:00,  3.64it/s]
100%|███████████████████████████████████████████| 19/19 [00:05<00:00,  3.38it/s]
100%|███████████████████████████████████████████| 53/53 [00:14<00:00,  3.63it/s]
100%|███████████████████████████████████████████| 12/12 [00:03<00:00,  3.01it/s]
100%|███████████████████████████████████████████| 15/15 [00:04<00:00,  3.23it/s]
100%|███████████████████████████████████████████| 11/11 [00:03<00:00,  2.96it/s]
100%|███████████████████████████████████████████| 30/30 [00:08<00:00,  3.48it/s]
100%|█████████████████████████████████████████████| 5/5 [00:02<00:00,  2.25it/s]
100%|███████████████████████████████████████████| 68/68 [00:18<00:00,  3.67it/s]
100%|███████████████████████████████████████████| 20/20 [00:06<00:00,  3.30it/s]
100%|█████████████████████████████████████████| 110/110 [00:28<00:00,  3.86it/s]
100%|████████████████████

In [None]:
save_encoding=os.path.join(HOME, "encoded_data_resnet18")
model = getResnet18()

In [None]:
for idx in range(len(dataset))[259:261]:
    encode(model, idx, device=device, save_encoding=save_encoding, batch_size=batch_size, num_workers=num_workers)

In [None]:
torchsummary.summary(model, (3, 224, 224))