In [1]:
# !pip install info-nce-pytorch

In [1]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from info_nce import InfoNCE 
from tqdm import tqdm

from conch.open_clip_custom import create_model_from_pretrained, get_tokenizer, tokenize

  from .autonotebook import tqdm as notebook_tqdm


### Load the model "create_model_from_pretrained"
TODO: Double check that image size of the conch is as expected (it runs with 224px images, but is it ok?)

In [9]:
import os

local_weights = False
model_cfg = 'conch_ViT-B-16'

if local_weights:
    checkpoint_path = './checkpoints/CONCH/pytorch_model.bin'
else:
    from dotenv import load_dotenv
    load_dotenv()  # take environment variables from .env.
    checkpoint_path = 'hf_hub:MahmoodLab/conch'
    hf_auth_token = os.getenv("HF_AUTH_TOKEN")
    if hf_auth_token is None:
        raise ValueError("HF_AUTH_TOKEN environment variable not set")
model, preprocess = create_model_from_pretrained(model_cfg, checkpoint_path, hf_auth_token=hf_auth_token)

# _ = model.eval()

# We are only interensted in the ViT part of the model
# Since the config states that attentional_pool_caption is true, then the default forward function does not use the head, nor normalization
# TODO: Check if head and l2 normalization are used in finetuning
model_vit = model.visual


In [8]:
print(model_vit)


VisualModel(
  (trunk): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')


In [3]:
# Freeze all layers
for param in model_vit.parameters():
    param.requires_grad = False

# Unfreeze the last 3 layer
for param in model_vit.trunk.blocks[-3:].parameters():
    param.requires_grad = True



In [4]:
for name, param in model_vit.named_parameters():
    print(f"Layer: {name} | Trainable: {param.requires_grad}")

Layer: proj_contrast | Trainable: False
Layer: trunk.cls_token | Trainable: False
Layer: trunk.pos_embed | Trainable: False
Layer: trunk.patch_embed.proj.weight | Trainable: False
Layer: trunk.patch_embed.proj.bias | Trainable: False
Layer: trunk.blocks.0.norm1.weight | Trainable: False
Layer: trunk.blocks.0.norm1.bias | Trainable: False
Layer: trunk.blocks.0.attn.qkv.weight | Trainable: False
Layer: trunk.blocks.0.attn.qkv.bias | Trainable: False
Layer: trunk.blocks.0.attn.proj.weight | Trainable: False
Layer: trunk.blocks.0.attn.proj.bias | Trainable: False
Layer: trunk.blocks.0.norm2.weight | Trainable: False
Layer: trunk.blocks.0.norm2.bias | Trainable: False
Layer: trunk.blocks.0.mlp.fc1.weight | Trainable: False
Layer: trunk.blocks.0.mlp.fc1.bias | Trainable: False
Layer: trunk.blocks.0.mlp.fc2.weight | Trainable: False
Layer: trunk.blocks.0.mlp.fc2.bias | Trainable: False
Layer: trunk.blocks.1.norm1.weight | Trainable: False
Layer: trunk.blocks.1.norm1.bias | Trainable: False
La

In [5]:
# Define 3-layer MLP
st_raw_count_dims = 280
output_dims = 512

model_mlp = torch.nn.Sequential(
    torch.nn.Linear(st_raw_count_dims, output_dims*2),
    torch.nn.ReLU(),
    torch.nn.Linear(output_dims*2, output_dims*2),
    torch.nn.ReLU(),
    torch.nn.Linear(output_dims*2, output_dims)
) 

In [7]:
# Pytroch dataloader for images only
import h5py
import scanpy as sc

class ImageExpressionDataset(Dataset):
    def __init__(self, cases, image_dir, expression_dir, selected_genes=None, transform=None):
        self.cases = cases
        self.image_dir = image_dir
        self.expression_dir = expression_dir
        self.transform = transform
        self.selected_genes = selected_genes
        
        self.image_size = 224
        
        # TODO: This is not the optimal way to load the data, 
        # as it loads all the data in memory, and is very slow at start up

        # From dataset, read what cases to load
        # Load cases as anndata
        # Filter genes that are included from list
        self.data_df = self.load_data()
        print(f'Loaded {len(self.data_df)} patches with expression data')


    def load_data(self):
        # For each case, load the expression data and the patches
        dfs = []
        for case in self.cases:
            # Load the patches
            df_patches = self.load_patches(case)
            # Load the expression
            adata = self.load_expressions(case)
            # Merge the data
            df_patches.loc[:, 'expression'] = list(adata[df_patches.barcode, :].X)
            
            dfs.append(df_patches)

        return pd.concat(dfs)

    def load_patches(self, case):
        # Open the file in read mode
        with h5py.File(f'{self.image_dir}/{case}.h5', 'r') as file:
            # Get the data
            imgs = list(file['img'])
            barcodes = list(file['barcode'])
            coords = list(file['coords'])

        barcodes = [b[0].decode('utf-8') for b in barcodes]
        df = pd.DataFrame(
            {
                'barcode': barcodes,
                'coord': coords,
                'img': [Image.fromarray(im) for im in imgs]
            }
        )
        return df


    def load_expressions(self, case):
       
        adata = sc.read_h5ad(f'{self.expression_dir}/{case}.h5ad')
        adata.obs['batch'] = case  # Add filename as batch key

        # Filter to include only selected genes
        if self.selected_genes is not None:
            adata = adata[:, self.selected_genes]

        # Normalize and log transform
        # TODO: Should this be done for the whole dataset or per case?
        sc.pp.normalize_total(adata)
        # Log transformation
        sc.pp.log1p(adata)
        return adata


    def __len__(self):
        return len(self.data_df.index)
   

    def __getitem__(self, index):
        row = self.data_df.iloc[index]
        
        image = row.img
        expression = row.expression

        if self.transform is not None:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        return image, expression



In [8]:
preprocess

Compose(
    Resize(size=448, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(448, 448))
    <function _convert_to_rgb at 0x7ff3ad524c10>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [9]:
# Create an instance of your dataset
import json
aug = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip()
    #transforms.Resize((224, 224)),
    #transforms.ToTensor(),
    #transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) # from conch preprocess
])

aug.transforms.extend(preprocess.transforms)

cases = ['TENX99', 'TENX96', 'TENX95', 'NCBI783', 'NCBI785']
expression_dir = '/home/hu-eki/Projects/HEST/tutorials/hest_data/st'
image_dir = '/home/hu-eki/Projects/HEST/tutorials/hest_data/patches'
selected_genes = json.load(open('filtered_genes.json'))

dataset = ImageExpressionDataset(cases=cases, image_dir=image_dir, expression_dir=expression_dir,
                                 selected_genes=selected_genes, transform=aug)

# # Create a PyTorch DataLoader
batch_size = 32 # TODO: change to 1024
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

  view_to_actual(adata)
  view_to_actual(adata)
  view_to_actual(adata)
  view_to_actual(adata)


Loaded 41893 patches with expression data


  view_to_actual(adata)


In [10]:
# Test dataloader
for image, expression in data_loader:
    print(image.shape)
    print(expression.shape)
    break

torch.Size([32, 3, 448, 448])
torch.Size([32, 280])


In [11]:
# TODO: Add layerwise learning rate decay rate of 0.7 for ViT model
# Define different learning rates for each layer
lr_decay = 0.7
lr_layer_3 = 1e-4
lr_layer_2 = lr_layer_3 * lr_decay
lr_layer_1 = lr_layer_2 * lr_decay

# Create parameter groups with different learning rates
param_groups = [
    {'params': model_vit.trunk.blocks[-3].parameters(), 'lr': lr_layer_1},
    {'params': model_vit.trunk.blocks[-2].parameters(), 'lr': lr_layer_2},
    {'params': model_vit.trunk.blocks[-1].parameters(), 'lr': lr_layer_3}
]

# Pass the parameter groups to the optimizer
optimizer_vit = torch.optim.Adam(param_groups)

# Set up the optimizer to only optimize the paramters that are trainable
# optimizer_vit = torch.optim.Adam(filter(lambda p: p.requires_grad, model_vit.parameters()), lr=0.0004)
# MLP optimizer
optimizer_mlp = torch.optim.Adam(model_mlp.parameters(), lr=0.003)

epochs = 50

scheduler_vit = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vit, T_max=epochs)
scheduler_mlp = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_mlp, T_max=epochs)

In [12]:
temperature = 0.02 # according to the paper
# info-nce-pytorch
infoloss = InfoNCE(temperature=temperature, negative_mode='unpaired')

In [13]:
model_vit.cuda()
model_mlp.cuda()

Sequential(
  (0): Linear(in_features=280, out_features=1024, bias=True)
  (1): ReLU()
  (2): Linear(in_features=1024, out_features=1024, bias=True)
  (3): ReLU()
  (4): Linear(in_features=1024, out_features=512, bias=True)
)

In [None]:

for epoch in range(1, epochs + 1):

    model_vit.train()
    model_mlp.train()

    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
    for img, expr in train_bar:
        img, expr = img.cuda(non_blocking=True), expr.cuda(non_blocking=True)
        out_1, _ = model_vit(img)
        out_2 = model_mlp(expr)
        
        loss = infoloss(out_1, out_2)
        optimizer_vit.zero_grad()
        optimizer_mlp.zero_grad()
        loss.backward()
        optimizer_vit.step()
        optimizer_mlp.step()

        total_num += batch_size
        total_loss += loss.item() * batch_size
        train_loss = total_loss / total_num
        train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, train_loss))

        
    
    scheduler_vit.step()
    scheduler_mlp.step()
    torch.save(model_vit.state_dict(), 'model_vit_50_preprocess.pth')

Train Epoch: [1/50] Loss: 2.2398:   4%|▍         | 50/1310 [00:34<14:20,  1.46it/s]


KeyboardInterrupt: 

In [3]:
samples = 5
for sample in range(1, samples+1):
    torch.save(model_vit.state_dict(), f'model_vit_base_id_{sample}.pth')

In [7]:
model, preprocess = create_model_from_pretrained(model_cfg, checkpoint_path, hf_auth_token=hf_auth_token)
model_vit = model.visual
torch.save(model_vit.state_dict(), f'model_vit_base_id_{9}.pth')
preprocess
hf_auth_token

  checkpoint = torch.load(checkpoint_path, map_location=map_location)


'your_hugging_face_token'

In [None]:
import os
import torch
from conch.open_clip_custom import create_model_from_pretrained

# Set up the configuration
local_weights = False
model_cfg = 'conch_ViT-B-16'

if local_weights:
    checkpoint_path = './checkpoints/CONCH/pytorch_model.bin'
else:
    checkpoint_path = 'hf_hub:MahmoodLab/conch'
    hf_auth_token = os.getenv("HF_AUTH_TOKEN")
    if hf_auth_token is None:
        raise ValueError("HF_AUTH_TOKEN environment variable not set")

# Load the model
model, preprocess = create_model_from_pretrained(model_cfg, checkpoint_path, hf_auth_token=hf_auth_token)

# We are only interested in the ViT part of the model
model_vit = model.visual

# Print the model to verify
print(model_vit)

: 