In [1]:
import os
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

from conch.open_clip_custom import create_model_from_pretrained, get_tokenizer, tokenize



In [2]:
# Read BCNB dataset

datafolder = '/mnt/d/data/BCNB/'

# Read train/val/test splits
def get_ids(filename):
    with open(f'{datafolder}/dataset-splitting/{filename}', 'r') as f:
        ids = f.readlines()
        ids = [int(x.strip()) for x in ids]
    return ids

train_ids = get_ids('train_id.txt')
val_ids = get_ids('val_id.txt')
test_ids = get_ids('test_id.txt')

In [3]:
# Read metadata
df_metadata = pd.read_excel(f'{datafolder}/patient-clinical-data.xlsx')
df_metadata


Unnamed: 0,Patient ID,Age(years),Tumour Size(cm),Tumour Type,ER,PR,HER2,HER2 Expression,Histological grading,Surgical,Ki67,Molecular subtype,Number of lymph node metastases,ALN status
0,1,77,3.0,Other type,Positive,Positive,Negative,0,,Axillary lymph node dissection,0.01,Luminal A,0,N0
1,2,39,3.5,Invasive ductal carcinoma,Negative,Negative,Negative,0,3.0,Sentinel lymph node biopsy,0.4,Triple negative,4,N+(>2)
2,3,52,3.0,Invasive ductal carcinoma,Positive,Positive,Negative,0,2.0,Axillary lymph node dissection,0.06,Luminal A,7,N+(>2)
3,4,60,2.3,Other type,Negative,Negative,Positive,3+,,Axillary lymph node dissection,0.6,HER2(+),0,N0
4,5,71,3.5,Invasive ductal carcinoma,Negative,Negative,Negative,0,,Axillary lymph node dissection,0.12,Triple negative,0,N0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1053,1054,60,2.0,Invasive ductal carcinoma,Positive,Positive,Positive,2+,,Axillary lymph node dissection,0.2,Luminal B,0,N0
1054,1055,63,2.0,Other type,Positive,Positive,Negative,1+,,Axillary lymph node dissection,0.1,Luminal A,0,N0
1055,1056,62,2.0,Invasive ductal carcinoma,Positive,Positive,Positive,3+,3.0,Axillary lymph node dissection,30-40%,HER2(+),0,N0
1056,1057,46,4.0,Invasive ductal carcinoma,Positive,Positive,Negative,2+,3.0,Axillary lymph node dissection,0.7,Luminal B,5,N+(>2)


In [4]:
# Read patches

data = []
for idx, row in df_metadata.iterrows():
    id = row['Patient ID']
    patches = os.listdir(f'{datafolder}/paper_patches/patches/{id}/')
    
    if id in train_ids:
        split = 'train'
    elif id in val_ids:
        split = 'val'
    elif id in test_ids:
        split = 'test'
    else:
        split = 'unknown'

    metadata = {
        'patient_id': row['Patient ID'],
        'age': row['Age(years)'],
        'tumor_size': row['Tumour Size(cm)'],
        'tumor_type': row['Tumour Type'],
        'er': row['ER'],
        'pr': row['PR'],	
        'her2': row['HER2'],
        'ki67': row['Ki67'],	
        'her2_expression': row['HER2 Expression'],
        'hist_grading': row['Histological grading'],
        'surgical': row['Surgical'],
        'molecular_subtype': row['Molecular subtype'],
        'lymph_node_metastases': row['Number of lymph node metastases'],	
        'aln_status': row['ALN status'],
        'split': split 
    }

    for patch in patches:
        patch_id = patch.split('.')[0]
        metadata['patch_id'] = patch_id
        metadata['patch_filename'] = f'{id}/{patch}'
        data.append(metadata)
    
df_patches = pd.DataFrame(data)

In [5]:
# get number of patients per split
df_patches.groupby('split').patient_id.nunique()

split
test     218
train    630
val      210
Name: patient_id, dtype: int64

In [6]:
len(df_patches)

76578

In [7]:
# Get embeddings
model_cfg = 'conch_ViT-B-16'
checkpoint_path = './checkpoints/CONCH/pytorch_model.bin'
model, preprocess = create_model_from_pretrained(model_cfg, checkpoint_path)

# _ = model.eval()

# We are only interensted in the ViT part of the model
# Since the config states that attentional_pool_caption is true, then the default forward function does not use the head, nor normalization
# TODO: Check if head and l2 normalization are used in finetuning
model_vit = model.visual
model_vit.eval()
model_vit.cuda()


  checkpoint = torch.load(checkpoint_path, map_location=map_location)


VisualModel(
  (trunk): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')


In [8]:
preprocess

Compose(
    Resize(size=448, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(448, 448))
    <function _convert_to_rgb at 0x7f124914fd90>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [22]:
aug = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) # from conch preprocess 
])


class ImageExpressionDataset(Dataset):
    def __init__(self, df, image_dir, transform=None):
        # we want to give all patches per patient
        self.df = df
        
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        
        image = Image.open(f'{self.image_dir}/{row.patch_filename}')

        if self.transform is not None:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)
        
        return image
    


In [40]:
batch_size = 4

for patient_id in df_patches.patient_id.unique():
    df_patient = df_patches[df_patches.patient_id == patient_id]
    dataset = ImageExpressionDataset(df_patient, f'{datafolder}/paper_patches/patches', transform=aug)
    
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    output = None
    for img in tqdm(data_loader):
        img = img.cuda(non_blocking=True)
        out_1, _ = model_vit(img)

        if output is None:
            output = out_1.cpu().detach()
        else:
            output = torch.cat((output, out_1.cpu().detach()), dim=0)
        
    avg_emb = torch.mean(output, dim=0).numpy()
    
    # save embeddings per patient
    if not os.path.isdir(f'{datafolder}/paper_patches/patches/{model_cfg}'):
        os.makedirs(f'{datafolder}/paper_patches/patches/{model_cfg}')
        
    np.save(f'{datafolder}/paper_patches/patches/{model_cfg}/{patient_id}_emb.npy', avg_emb)
    break
        

100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.83it/s]


In [43]:
np.load(f'{datafolder}/paper_patches/patches/{model_cfg}/{patient_id}_emb.npy')

array([-4.78522629e-02,  9.65955853e-01,  1.26113147e-01,  9.18906152e-01,
        5.22686183e-01,  1.68825912e+00, -2.67322683e+00, -7.46284187e-01,
        1.74429789e-01, -8.36225629e-01,  5.76131821e-01, -1.14771426e+00,
        7.96293259e-01,  3.75898778e-01, -2.58326125e+00, -1.01727402e+00,
       -1.21100545e+00, -6.72673523e-01, -1.18650818e+00,  2.80674338e-01,
       -2.01750803e+00, -2.00845361e+00,  1.10961139e+00,  4.83080029e-01,
        1.86694539e+00, -6.10061362e-02, -8.75446618e-01, -8.23609412e-01,
        1.89954996e-01, -1.48449910e+00,  1.79805830e-01, -6.29979730e-01,
       -5.19608915e-01,  1.39089012e+00,  1.45456624e+00,  1.19226766e+00,
        1.76135814e+00,  6.43511534e-01,  9.99120414e-01, -1.02252571e-03,
        2.68281198e+00,  2.03658730e-01,  1.74051690e+00, -1.61075985e+00,
        5.84448099e-01, -1.75832525e-01,  1.10334232e-01,  4.79964375e-01,
        6.92903757e-01, -1.22373283e+00, -5.94353210e-03, -1.88434172e+00,
        6.98218465e-01, -