In [1]:
import os

import tifffile
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import wandb
from torchvision import transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader, Dataset
from einops import rearrange
from kmeans_pytorch import kmeans

In [2]:
# Authenticate this session with Weights & Biases before a run is created further down.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mestorrs[0m ([33mtme-st[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Load IPython's autoreload extension; the reload mode is selected in the next cell.
%load_ext autoreload

In [4]:
# Mode 2 re-imports modules before each cell executes, so edits to imported
# modules are picked up without restarting the kernel.
%autoreload 2

In [5]:
from dino_extended.data.utils import listfiles, extract_ome_tiff, get_ome_tiff_channels, make_pseudo
from dino_extended.data.multiplex import TileTransform, TileDataset, MultichannelAug
from dino_extended.models.dino import Dino
from dino_extended.models.vit import ViT, Recorder, Extractor

In [6]:
# Collect the registered OME-TIFF files for this sample, sorted so the order is deterministic.
# The dots are escaped so the pattern matches only a literal ".ome.tiff" suffix; an
# unescaped "." matches any character and would also accept names such as "some_tiff".
# NOTE(review): assumes `listfiles` applies `regex` with search semantics — confirm.
fps = sorted(listfiles('/data/estorrs/mushroom/data/test_registration/HT397B1/registered/', regex=r'\.ome\.tiff$'))
fps

['/data/estorrs/mushroom/data/test_registration/HT397B1/registered/s1.ome.tiff',
 '/data/estorrs/mushroom/data/test_registration/HT397B1/registered/s2.ome.tiff']

In [7]:
# Print the channel list of every file so differences between sections are visible.
# The loop variable is used (not fps[0]) so each file is actually inspected.
for fp in fps:
    print(get_ome_tiff_channels(fp))

  d = to_dict(os.fspath(xml), parser=parser, validate=validate)


['DAPI', 'CD8', 'Her2 (D)', 'GATA3 (D)', 'cKIT-(D)', 'Pan-Cytokeratin', 'GLUT1-(D)', 'Podoplanin', 'CD68 (D)', 'HLA-DR', 'Keratin 14', 'FoxP3', 'MGP-(D)', 'CD20-(D)', 'SMA-(D)', 'Ki67', 'Vimentin-(D)', 'PR-(D)', 'Bap1 (D)', 'CD45 (D)', 'ER', 'CD31', 'COX6c (D)', 'CK19', 'PLAT/tPA (D)']
['DAPI', 'CD8', 'Her2 (D)', 'GATA3 (D)', 'cKIT-(D)', 'Pan-Cytokeratin', 'GLUT1-(D)', 'Podoplanin', 'CD68 (D)', 'HLA-DR', 'Keratin 14', 'FoxP3', 'MGP-(D)', 'CD20-(D)', 'SMA-(D)', 'Ki67', 'Vimentin-(D)', 'PR-(D)', 'Bap1 (D)', 'CD45 (D)', 'ER', 'CD31', 'COX6c (D)', 'CK19', 'PLAT/tPA (D)']


In [8]:
# List every channel present in the first file, sorted by name, as a reference
# for choosing the subset used for training.
# NOTE: `channels` is rebound to the curated subset in the next cell.
channels = sorted(get_ome_tiff_channels(fps[0]))
channels

['Bap1 (D)',
 'CD20-(D)',
 'CD31',
 'CD45 (D)',
 'CD68 (D)',
 'CD8',
 'CK19',
 'COX6c (D)',
 'DAPI',
 'ER',
 'FoxP3',
 'GATA3 (D)',
 'GLUT1-(D)',
 'HLA-DR',
 'Her2 (D)',
 'Keratin 14',
 'Ki67',
 'MGP-(D)',
 'PLAT/tPA (D)',
 'PR-(D)',
 'Pan-Cytokeratin',
 'Podoplanin',
 'SMA-(D)',
 'Vimentin-(D)',
 'cKIT-(D)']

In [9]:
# Channels used for training. An entry is either a bare channel name or a
# (name, percentile) tuple; the tuple form overrides the contrast-clipping
# percentile that preprocess_ome would otherwise apply to that channel.
channels = [
    # 'CD68 (D)',  -- not working in this batch
    'CD20-(D)',
    ('CD31', 98.0),
    ('CD45 (D)', 98.0),
    'CD8',
    'CK19',
    ('DAPI', 98.0),
    'ER',
    ('FoxP3', 99.5),
    'GATA3 (D)',
    'GLUT1-(D)',
    'HLA-DR',
    'Her2 (D)',
    ('Keratin 14', 98.0),
    'Ki67',
    ('MGP-(D)', 98.0),
    # 'PLAT/tPA (D)',
    'PR-(D)',
    ('Pan-Cytokeratin', 95.0),
    'Podoplanin',
    'SMA-(D)',
    'Vimentin-(D)',
    'cKIT-(D)',
]

# Bare channel names in the same order as `channels`; used to look up channel indices.
channel_names = [entry[0] if isinstance(entry, tuple) else entry for entry in channels]

In [10]:
from skimage.exposure import rescale_intensity
def _normalize_channel(img, contrast_pct):
    """Clip one channel at a percentile of its positive pixels and min-max scale it.

    Args:
        img: numpy array holding a single channel. It is clipped in place
            before being converted to float32.
        contrast_pct: percentile (0-100) of the strictly positive pixels used
            as the upper clipping value.

    Returns:
        float32 torch tensor with a new leading dimension of size 1. Values lie
        in [0, 1]; a constant channel comes back as all zeros.
    """
    positive = img[img > 0]
    # Without positive pixels there is nothing to take a percentile of; 1. is a harmless ceiling.
    vmax = np.percentile(positive, contrast_pct) if positive.size else 1.
    img[img > vmax] = vmax

    img = img.astype(np.float32)
    img -= img.min()
    peak = img.max()
    # A constant channel has zero range; dividing would fill it with NaN.
    if peak > 0:
        img /= peak
    return torch.tensor(img, dtype=torch.float32).unsqueeze(dim=0)


def preprocess_ome(fp, scale=None, default_contrast_pct=90., channels=None):
    """Load channels from an OME-TIFF, contrast-clip and normalise each, and stack them.

    Args:
        fp: path to the OME-TIFF file.
        scale: currently unused; kept so existing callers keep working.
        default_contrast_pct: clipping percentile for channels given without
            an explicit percentile.
        channels: optional list whose entries are either a channel name or a
            (name, percentile) tuple. When None, every channel reported by
            get_ome_tiff_channels(fp) is used, sorted by name.

    Returns:
        (stacked, channels): `stacked` is a float32 tensor with one entry per
        channel along dim 0, in the order of the returned `channels` name list.
    """
    if channels is not None:
        contrast_pcts = [c[1] if isinstance(c, tuple) else default_contrast_pct for c in channels]
        channels = [c[0] if isinstance(c, tuple) else c for c in channels]
    else:
        channels = sorted(get_ome_tiff_channels(fp))
        contrast_pcts = [default_contrast_pct] * len(channels)

    channel_to_img = extract_ome_tiff(fp, channels=channels)

    for c, contrast_pct in zip(channels, contrast_pcts):
        print(c)
        # Overwrite the raw plane so it can be released as soon as it is normalised.
        channel_to_img[c] = _normalize_channel(channel_to_img[c], contrast_pct)

    stacked = torch.concat([channel_to_img[c] for c in channels], dim=0)

    return stacked, channels

In [None]:
# Preprocess every registered section with the shared channel spec, keeping only
# the stacked tensor from each call.
imgs = []
for fp in fps:
    print(fp)
    stacked = preprocess_ome(fp, default_contrast_pct=90., channels=channels)[0]
    imgs.append(stacked)
len(imgs)

/data/estorrs/mushroom/data/test_registration/HT397B1/registered/s1.ome.tiff
CD20-(D)
CD31
CD45 (D)
CD8
CK19
DAPI
ER
FoxP3


In [None]:
# Tiling parameters handed to TileDataset.
# NOTE(review): how TileDataset combines `size` and `scale` is not visible in
# this notebook — confirm in dino_extended.data.multiplex.
size = (256, 256)
scale = .5
ds = TileDataset(imgs, size=size, scale=scale)

In [None]:
# Pull one tile to sanity-check what the dataset returns.
tile = ds[0]
tile.shape

In [None]:
# Display the DAPI channel of the sample tile (index looked up by channel name).
plt.imshow(tile[channel_names.index('DAPI')])

In [None]:
# Per-channel mean and std, reduced over the last two dimensions of the first
# image only; these are later passed to Dino as `means` / `stds`.
means, stds = imgs[0].mean(dim=(-2, -1)), imgs[0].std(dim=(-2, -1))
means, stds

In [None]:
# Shuffled training loader over the tile dataset, using 10 worker processes.
batch_size = 64
dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=10)

In [None]:
# Fetch a single batch to check the collated shape.
# NOTE: `b` is reused as the batch variable inside log_media and the training loop.
b = next(iter(dl))
b.shape

In [None]:
# Number of input channels for the model: one per selected channel name.
n_image_channels = len(channel_names)

In [None]:
from pathlib import Path
# Output locations for this project: <run_dir>/logs and <run_dir>/chkpts.
project = 'codex_dino'
run_dir = f'/data/estorrs/DINO-extended/data/runs/HT397B1/{project}'
log_dir, chkpt_dir = (os.path.join(run_dir, sub) for sub in ('logs', 'chkpts'))

# Create both directories (and any missing parents); existing ones are left untouched.
for directory in (log_dir, chkpt_dir):
    Path(directory).mkdir(parents=True, exist_ok=True)

In [None]:
# Start a W&B run under `project`; the returned handle is kept in `run`.
run = wandb.init(
  project=project,
)

In [None]:
# wandb.finish()

In [None]:
# Hyperparameters for this run. The 'vit' and 'dino' sections mirror the keyword
# arguments used to construct the ViT backbone and the Dino learner below.
config = {
    'epochs': 1000,
    'lr': 3e-4,
    'batch_size': batch_size,
    'vit': {
        'image_size': size[0],
        'channels': n_image_channels,
        'patch_size': 16,                   # matches the ViT that is actually constructed below
        'num_classes': 100,
        'dim': 1024,
        'depth': 6,
        'heads': 8,
        'mlp_dim': 2048
    },
    'dino': {
        'is_multichannel': True,
        'n_image_channels': n_image_channels,
        'means': means.tolist(),            # plain Python floats so the config serialises cleanly
        'stds': stds.tolist(),
        'image_size': size[0],
        'hidden_layer': 'to_latent',
        'projection_hidden_size': 256,      # projector network hidden dimension
        'projection_layers': 4,             # number of layers in projection network
        'num_classes_K': 65336,             # output logits dimensions (referenced as K in paper)
        'student_temp': 0.9,                # student temperature
        'teacher_temp': 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
        'local_upper_crop_scale': 0.4,      # upper bound for local crop - 0.4 was recommended in the paper
        'global_lower_crop_scale': 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
        'moving_average_decay': 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
        'center_moving_average_decay': 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
        'teacher_temps': np.linspace(.04, .07, 30).tolist()
    }
}

# `wandb.config = config` would only rebind the module attribute; update() stores the
# values on the active run. allow_val_change lets this cell be re-run after edits.
wandb.config.update(config, allow_val_change=True)

In [None]:
# ViT backbone hyperparameters, gathered in one mapping and unpacked into the constructor.
vit_kwargs = dict(
    image_size=size[0],
    channels=n_image_channels,
    patch_size=16,
    num_classes=100,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
)
model = ViT(**vit_kwargs)

In [None]:
# Keyword arguments for the DINO wrapper around `model`, collected first and then unpacked.
dino_kwargs = dict(
    is_multichannel=True,
    n_image_channels=n_image_channels,
    means=means,
    stds=stds,
    image_size=size[0],
    hidden_layer='to_latent',          # name or index of the hidden layer the embedding is taken from
    projection_hidden_size=256,        # hidden dimension of the projector network
    projection_layers=4,               # number of layers in the projector network
    num_classes_K=65336,               # dimension of the output logits (K in the paper)
    student_temp=0.9,                  # student temperature
    teacher_temp=0.04,                 # teacher temperature; needs annealing from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale=0.4,        # upper bound for the local crop (0.4 recommended in the paper)
    global_lower_crop_scale=0.5,       # lower bound for the global crop (0.5 recommended in the paper)
    moving_average_decay=0.9,          # encoder moving average (paper: 0.9 to 0.999 all worked)
    center_moving_average_decay=0.9,   # teacher-center moving average (paper: 0.9 to 0.999 all worked)
)
learner = Dino(model, **dino_kwargs)

In [None]:
def get_cluster_images(model, embeddings, num_clusters=10, device=None):
    """Cluster per-patch embeddings with k-means and lay the cluster ids out as images.

    Args:
        model: unused; kept so existing callers keep working.
        embeddings: tensor of shape (batch, 1 + n_patches, dim). The first
            token of every item is dropped before clustering
            (presumably the class token — confirm against the ViT).
        num_clusters: number of k-means clusters.
        device: device k-means runs on. Defaults to 'cuda:0' when CUDA is
            available, otherwise the CPU.

    Returns:
        float32 tensor of shape (batch, side, side) with side**2 == n_patches,
        holding cluster ids min-max scaled to [0, 1]. If every patch falls in
        the same cluster the result is all zeros.

    Raises:
        ValueError: if n_patches is not a perfect square, since the patches
            could not be arranged on a square grid.
    """
    x = embeddings[:, 1:]
    n_images, n_patches, dim = x.shape

    side = int(round(n_patches ** 0.5))
    if side * side != n_patches:
        raise ValueError(f'expected a square number of patches, got {n_patches}')

    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Cluster all patches of all images jointly so ids are comparable across images.
    cluster_ids_x, _ = kmeans(
        X=x.reshape(n_images * n_patches, dim),
        num_clusters=num_clusters,
        tol=1.,
        distance='euclidean',
        device=device,
    )

    img = cluster_ids_x.reshape(n_images, side, side).to(torch.float32)
    img -= img.min()
    top = img.max()
    # A single populated cluster leaves zero range; dividing would produce NaN.
    if top > 0:
        img /= top

    return img

def log_media(model, dl, n_show=16, n_batches=5, channels=('DAPI', 'Pan-Cytokeratin', 'CD45 (D)')):
    """Log pseudo-colour tiles and their patch-cluster maps to W&B.

    Embeddings are extracted from up to `n_batches` batches of `dl`, clustered
    with get_cluster_images, and the first `n_show` cluster maps are logged
    together with pseudo-colour renderings of the first batch.

    Args:
        model: network to wrap in an Extractor; batches are moved to the
            device of its parameters.
        dl: iterable yielding image batches indexed as (batch, channel, ...).
        n_show: maximum number of images logged.
        n_batches: maximum number of batches used for clustering.
        channels: channel names (looked up in the notebook-level
            `channel_names`) that are passed to make_pseudo.

    Returns:
        (cluster_imgs, pseudo): numpy arrays of the logged cluster maps
        (batch, h, w, 1) and the pseudo-colour images.

    Raises:
        ValueError: if `dl` yields no batch.
    """
    channel_idxs = [channel_names.index(c) for c in channels]
    device = next(model.parameters()).device

    embeddings = []
    imgs = None
    v = Extractor(model)
    try:
        with torch.no_grad():
            for i, b in enumerate(dl):
                if i == n_batches:
                    break
                b = b.to(device)
                _, embs = v(b)
                embeddings.append(embs)

                # Only the first batch is rendered; it lines up with the first
                # rows of the concatenated embeddings.
                if imgs is None:
                    imgs = b[:n_show, channel_idxs]
    finally:
        # Always detach the extractor from the model, even if a batch failed.
        v.eject()

    if imgs is None:
        raise ValueError('log_media received no batches from the data loader')

    embeddings = torch.concat(tuple(embeddings), dim=0)
    cluster_imgs = get_cluster_images(model, embeddings)[:n_show].unsqueeze(dim=1).numpy()
    cluster_imgs = rearrange(cluster_imgs, 'b c h w -> b h w c')

    pseudo = [np.expand_dims(make_pseudo(img.detach().cpu().numpy()), 0) for img in imgs]
    pseudo = np.concatenate(pseudo, axis=0)

    wandb.log({
        'train/pseudo': [wandb.Image(x) for x in pseudo],
        'train/clustered': [wandb.Image(x) for x in cluster_imgs]
    })

    return cluster_imgs, pseudo

In [None]:
# Number of batches per epoch.
len(dl)

In [None]:
# Move the learner to device 0; training batches are moved to the same device.
learner = learner.to(0)

In [None]:
# Teacher temperature schedule: 30 evenly spaced values from 0.04 to 0.07,
# indexed by epoch in the training loop.
teacher_temps = np.linspace(.04, .07, 30)
teacher_temps

In [None]:
# Read the training length and learning rate from `config` and build an Adam
# optimizer over the learner's parameters.
epochs = config['epochs']
lr = config['lr']
opt = torch.optim.Adam(learner.parameters(), lr=lr)

In [None]:
# Intervals, in epochs, for logging media to W&B and for saving checkpoints.
media_log_every = 10
chkpt_every = 50

In [None]:
# Training loop: one optimizer step and one teacher moving-average update per batch,
# with per-epoch loss logging, periodic media logging, teacher-temperature
# annealing and checkpointing of the backbone weights.
last_epoch = epochs - 1
for e in range(epochs):
    epoch_loss = 0.
    for b in dl:
        b = b.to(0)
        loss = learner(b)
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average() # update moving average of teacher encoder and teacher centers

        epoch_loss += loss.item()

    # NOTE(review): 'n_steps' is e * len(dl) * batch_size, i.e. roughly the number of
    # samples seen before this epoch rather than a count of optimizer steps.
    result = {"train/loss": epoch_loss / len(dl), 'n_steps': e * len(dl) * batch_size, 'epoch': e}
    print(result)
    wandb.log(result)
    if e % media_log_every == 0:
        log_media(model, dl)

    # Advance the teacher temperature for the next epoch until the schedule is exhausted.
    if e + 1 < len(teacher_temps):
        learner.teacher_temp = teacher_temps[e + 1]

    # Save on the regular interval and also after the final epoch, so the
    # fully trained weights are never lost.
    if e % chkpt_every == 0 or e == last_epoch:
        torch.save(model.state_dict(), os.path.join(chkpt_dir, f'{e}.pt'))
    