In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
from omegaconf import OmegaConf
from trainer import Trainer
from os.path import join as os_join
import utils

First, download the `model_checkpoints` folder from this [Google Drive link](https://drive.google.com/drive/folders/1_xVgzfdc6H9ar4T5bd1jTjNkrpTwkSlL).

In [2]:
## load the checkpoint's YAML config and echo it for inspection
path = "checkpoint_configs/hypernet.yaml"
cfg = OmegaConf.load(path)
cfg_yaml = OmegaConf.to_yaml(cfg)  ## YAML text of the loaded config
print(cfg_yaml)

train:
  batch_strategy: random_instance
  resume_train: false
  resume_model: hypernet.pt
  use_amp: false
  checkpoints: ../STORE/adaptive_interface/checkpoints
  save_model: no_save
  clip_grad_norm: null
  batch_size: 128
  num_epochs: 15
  verbose_batches: 50
  seed: 125617
  debug: false
  adaptive_interface_epochs: 0
  adaptive_interface_lr: null
  swa: false
  swad: false
  swa_lr: 0.05
  swa_start: 5
  miro: false
  miro_lr_mult: 10.0
  miro_ld: 0.01
  tps_prob: 0.0
model:
  name: hyperconvnext
  pretrained: true
  pretrained_model_name: convnext_tiny.fb_in22k
  in_dim: null
  num_classes: null
  pooling: avg
  temperature: 0.07
  learnable_temp: false
  unfreeze_last_n_layers: -1
  unfreeze_first_layer: true
  first_layer: reinit_as_random
  reset_last_n_unfrozen_layers: false
  use_auto_rgn: false
  z_dim: 128
  hidden_dim: 256
  in_channel_names:
  - er
  - golgi
  - membrane
  - microtubules
  - mito
  - nucleus
  - protein
  - rna
  separate_emb: true
scheduler:
  name: c

In [3]:
## quick test: use an eval batch size of 1 (a single image)
cfg.eval.batch_size = 1

## point the checkpoint directory at the downloaded `model_checkpoints` folder
## (this is an example value; adjust it if you saved the folder elsewhere)
cfg.train.checkpoints = "model_checkpoints"

In [4]:
## build the trainer (which also builds the model) and restore the checkpoint
## afterwards the model is available as `trainer.model`

trainer = Trainer(cfg)

## checkpoint file = <train.checkpoints>/<train.resume_model>
train_cfg = cfg.train
checkpoint_path = os_join(train_cfg.checkpoints, train_cfg.resume_model)

## kept as the last expression so its return value is shown as the cell output
trainer._load_model(checkpoint_path)

device:  cuda
chunks [{'Allen': ['nucleus', 'membrane', 'protein']}, {'HPA': ['microtubules', 'protein', 'nucleus', 'er']}, {'CP': ['nucleus', 'er', 'rna', 'golgi', 'mito']}]
channels [3, 4, 5]
train:
  batch_strategy: random_instance
  resume_train: false
  resume_model: hypernet.pt
  use_amp: false
  checkpoints: ../STORE/adaptive_interface/checkpoints
  save_model: no_save
  clip_grad_norm: null
  batch_size: 128
  num_epochs: 15
  verbose_batches: 50
  seed: 125617
  debug: false
  adaptive_interface_epochs: 0
  adaptive_interface_lr: 0.04
  swa: false
  swad: false
  swa_lr: 0.05
  swa_start: 5
  miro: false
  miro_lr_mult: 10.0
  miro_ld: 0.01
  tps_prob: 0.0
model:
  name: hyperconvnext
  pretrained: true
  pretrained_model_name: convnext_tiny.fb_in22k
  in_dim: 5
  num_classes: 14
  pooling: avg
  temperature: 0.07
  learnable_temp: false
  unfreeze_last_n_layers: -1
  unfreeze_first_layer: true
  first_layer: reinit_as_random
  reset_last_n_unfrozen_layers: false
  use_auto_rg

15

In [5]:
## choose which dataset chunk to evaluate: "HPA", "CP" or "Allen"
chunk_name = "HPA"

## look up that chunk's test dataloader on the trainer
test_loaders = trainer.test_loaders
eval_loader = test_loaders[chunk_name]

In [6]:
## predict on the first batch
trainer.model.eval()  ## put the model in evaluation mode before inference
with torch.inference_mode():  ## no gradient tracking is needed for a forward pass
    ## Take exactly one batch. `next(iter(...))` states that intent directly and
    ## raises StopIteration right here if the loader is empty, instead of silently
    ## leaving `output` undefined for the next cell.
    batch = next(iter(eval_loader))
    x = utils.move_to_cuda(batch, trainer.device)
    ## the chunk name is passed alongside the batch, as in the original call
    output = trainer._forward_model(x, chunk_name)

In [7]:
## dimensions of the model output for the single-image batch
output.size()

torch.Size([1, 768])

To evaluate on the whole dataset and obtain the metrics and visualizations, run:

```python
trainer.eval_morphem70k(0)
```

Note: the quick test above sets the eval `batch_size` to 1; you may want to use a larger value for the full evaluation. The full evaluation run will take a while to complete.


For more details on the evaluation package, please visit https://github.com/broadinstitute/MorphEm