In [13]:
import warnings
# Silence every warning category. This runs before the remaining imports,
# so warnings emitted while those modules load are hidden as well.
# NOTE(review): this also hides deprecation/runtime warnings that may matter — confirm it is intended.
warnings.filterwarnings('ignore')

import torch
from omegaconf import OmegaConf
from trainer import Trainer
from os.path import join as os_join
import utils


In [17]:
## Load the YAML config that was stored for this model checkpoint,
## then print it back as YAML so the settings can be inspected.
path = "checkpoint_configs/hypernet.yaml"
cfg = OmegaConf.load(path)
cfg_yaml = OmegaConf.to_yaml(cfg)
print(cfg_yaml)

train:
  batch_strategy: random_instance
  resume_train: false
  resume_model: hypernet.pt
  use_amp: false
  checkpoints: ../STORE/adaptive_interface/checkpoints
  save_model: no_save
  clip_grad_norm: null
  batch_size: 128
  num_epochs: 15
  verbose_batches: 50
  seed: 125617
  debug: false
  adaptive_interface_epochs: 0
  adaptive_interface_lr: null
  swa: false
  swad: false
  swa_lr: 0.05
  swa_start: 5
  miro: false
  miro_lr_mult: 10.0
  miro_ld: 0.01
  tps_prob: 0.0
model:
  name: hyperconvnext
  pretrained: true
  pretrained_model_name: convnext_tiny.fb_in22k
  in_dim: null
  num_classes: null
  pooling: avg
  temperature: 0.07
  learnable_temp: false
  unfreeze_last_n_layers: -1
  unfreeze_first_layer: true
  first_layer: reinit_as_random
  reset_last_n_unfrozen_layers: false
  use_auto_rgn: false
  z_dim: 128
  hidden_dim: 256
  in_channel_names:
  - er
  - golgi
  - membrane
  - microtubules
  - mito
  - nucleus
  - protein
  - rna
  separate_emb: true
scheduler:
  name: c

In [3]:
## Quick test on a single image: shrink the eval batch size to 1.
cfg.eval.batch_size = 1

In [4]:
## Build the trainer from the config, then load the checkpoint into it.
## The model itself is reachable afterwards as `trainer.model`.

trainer = Trainer(cfg)
train_cfg = cfg.train
checkpoint_path = os_join(train_cfg.checkpoints, train_cfg.resume_model)
## kept as the last expression so the cell still displays the returned value
trainer._load_model(checkpoint_path)

device:  cuda
chunks [{'Allen': ['nucleus', 'membrane', 'protein']}, {'HPA': ['microtubules', 'protein', 'nucleus', 'er']}, {'CP': ['nucleus', 'er', 'rna', 'golgi', 'mito']}]
channels [3, 4, 5]
train:
  batch_strategy: random_instance
  resume_train: false
  resume_model: hypernet.pt
  use_amp: false
  checkpoints: ../STORE/adaptive_interface/checkpoints
  save_model: no_save
  clip_grad_norm: null
  batch_size: 128
  num_epochs: 15
  verbose_batches: 50
  seed: 125617
  debug: false
  adaptive_interface_epochs: 0
  adaptive_interface_lr: 0.04
  swa: false
  swad: false
  swa_lr: 0.05
  swa_start: 5
  miro: false
  miro_lr_mult: 10.0
  miro_ld: 0.01
  tps_prob: 0.0
model:
  name: hyperconvnext
  pretrained: true
  pretrained_model_name: convnext_tiny.fb_in22k
  in_dim: 5
  num_classes: 14
  pooling: avg
  temperature: 0.07
  learnable_temp: false
  unfreeze_last_n_layers: -1
  unfreeze_first_layer: true
  first_layer: reinit_as_random
  reset_last_n_unfrozen_layers: false
  use_auto_rg

15

In [6]:
## Pick the dataset chunk to run on and fetch its test dataloader.
## Valid chunk names: "HPA", "CP", "Allen".
chunk_name = "HPA"
eval_loader = trainer.test_loaders[chunk_name]

In [7]:
## Predict on the first batch only.
## - `next(iter(...))` fetches one batch directly instead of a loop that breaks
##   after its first iteration (the enumerate index was never used).
## - An empty loader now fails loudly here, rather than leaving `output`
##   undefined or silently stale from an earlier run of this cell.
## - `torch.no_grad()` skips autograd bookkeeping, which inference does not need.
## NOTE(review): confirm the model is in eval mode at this point (e.g. that
## `_load_model` / `_forward_model` takes care of it).
batch = next(iter(eval_loader), None)
if batch is None:
    raise RuntimeError(f"eval_loader for chunk {chunk_name!r} yielded no batches")

with torch.no_grad():
    x = utils.move_to_cuda(batch, trainer.device)
    output = trainer._forward_model(x, chunk_name)

In [8]:
## display the raw tensor returned by `trainer._forward_model` for the first batch
output

tensor([[-2.3708e-01,  4.8151e-01,  1.3948e+00,  2.6062e-01, -1.0426e+00,
         -1.1895e+00, -8.3807e-02, -4.9599e-01, -5.3987e-01,  4.6434e-01,
          1.4282e+00, -8.1709e-01, -1.8891e-01, -7.6706e-01,  7.9759e-01,
         -8.7906e-01, -3.0676e-01,  4.2329e-01,  2.0100e+00, -3.2558e-02,
          8.6898e-01,  2.0236e-01, -1.8164e+00, -1.8529e+00, -1.9424e+00,
         -3.7470e-01, -1.9068e+00,  1.4251e-01,  9.2694e-01,  2.1006e+00,
          6.6473e-01, -2.8709e+00,  8.8675e-01,  1.6067e+00, -4.3679e-01,
          1.2245e+00, -1.2845e+00,  7.1122e-01, -1.2272e+00, -8.9891e-01,
          1.0262e+00,  3.5620e-01,  1.5320e+00, -3.2812e-01,  2.8977e-01,
         -4.2916e-01, -5.5874e-01, -6.4917e-01,  7.1124e-01, -6.1677e-01,
          7.2687e-01, -9.0863e-01,  1.1209e+00,  4.2012e-01,  2.8815e-02,
         -1.4981e+00,  8.1967e-01, -1.0982e+00, -4.7885e-01, -1.1211e-01,
          2.6869e-01, -5.2503e-02,  4.9140e-02,  3.4887e+00, -2.4342e+00,
          4.6784e-02, -4.1378e-01, -1.

In [9]:
## shape of the model output; presumably (batch, feature_dim), with eval batch_size set to 1 above — TODO confirm
output.shape

torch.Size([1, 768])

To evaluate on the whole dataset and obtain the metrics and visualizations, run:

```trainer.eval_morphem70k(0)```

Note: you may want to use a larger eval `batch_size` than the value of 1 used in the quick test above. The full evaluation takes a while to complete.


For more details on the evaluation package, please visit https://github.com/broadinstitute/MorphEm