In [1]:
import hydra
from src.commons.constants import PROJECT_PATH
from omegaconf import DictConfig, OmegaConf

import numpy as np
import cv2
from PIL import Image
import torch
import pytorch_lightning as pl
import pandas as pd
import os
from copy import deepcopy
import torch.nn.functional as F
from torchmetrics import Metric
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from torch.utils import data
from torch.nn.utils.rnn import pad_sequence

from src.commons.utils_io import load_sam, make_path
from src.commons.utils import to_numpy, SegAnyChangeVersion, show_img, show_pair_img, show_prediction_sample, resize
from src.models.commons.mask_process import extract_object_from_batch, binarize_mask
from src.commons.constants import *
from src.data.process import generate_grid_prompt
from src.commons.utils import create_sample_grid_with_prompt, get_mask_with_prompt, fig2arr

### Load the config for a new run (Hydra compose)

In [2]:
def load_config(
    config_name="train",
    overrides=("experiment=probing_diff", "sam_type=small", "data=levir-cd"),
    config_path="../configs",
):
    """Compose and return the Hydra config used by this notebook.

    Any global Hydra state left over from a previous call is cleared first,
    so the function can be called again (e.g. when re-executing a cell)
    instead of failing because Hydra is already initialized.

    Args:
        config_name: Name of the primary config (without ``.yaml``) inside
            ``config_path``. Defaults to ``"train"``.
        overrides: Hydra override strings applied on top of the primary
            config. Defaults to the probing_diff / small SAM / LEVIR-CD setup
            this notebook was written for.
        config_path: Config directory handed to ``hydra.initialize``.

    Returns:
        The composed Hydra config.
    """
    from hydra.core.global_hydra import GlobalHydra

    # hydra.initialize() may only run once per process; reset first so this
    # function is safe to call repeatedly.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    hydra.initialize(config_path=config_path, version_base=None)

    # A tuple default avoids a shared mutable default; compose takes a list.
    return hydra.compose(config_name=config_name, overrides=list(overrides))

In [3]:
from hydra.core.global_hydra import GlobalHydra

# Reset any global Hydra state from an earlier execution so load_config()
# can initialize Hydra again, then show the composed config as YAML.
hydra_global_state = GlobalHydra.instance()
hydra_global_state.clear()

cfg = load_config()
cfg_as_yaml = OmegaConf.to_yaml(cfg)
print(cfg_as_yaml)

data:
  name: levir-cd
  _target_: src.data.datamodule.CDDataModule
  params:
    prompt_type: sample
    n_prompt: 1
    loc: center
    batch_size: 2
    num_worker: 4
    pin_memory: false
    n_shape: 3
model:
  network:
    image_encoder:
      _target_: src.models.segment_anything.modeling.image_encoder_dev.ImageEncoderViT
      depth: 12
      embed_dim: 768
      img_size: 1024
      mlp_ratio: 4
      norm_layer: null
      num_heads: 12
      patch_size: 16
      qkv_bias: true
      use_rel_pos: true
      global_attn_indexes:
      - 2
      - 5
      - 8
      - 11
      window_size: 14
      out_chans: 256
    prompt_encoder:
      _target_: src.models.segment_anything.modeling.prompt_encoder_dev.PromptEncoder
      embed_dim: 256
      image_embedding_size:
      - 64
      - 64
      input_image_size:
      - 1024
      - 1024
      mask_in_chans: 16
    mask_decoder:
      transformer:
        _target_: src.models.segment_anything.modeling.transformer_dev.TwoWayTransfo

In [4]:
# Instantiate the object described by `cfg.model.instance`.
module_instance_cfg = cfg.model.instance
module = hydra.utils.instantiate(module_instance_cfg)

INIT VIT


2024-08-08 15:42:38,683 - INFO ::  Weights loaded for : ['image_encoder']


In [None]:
# The network held by the instantiated module.
model = getattr(module, "model")

In [None]:
# Build the datamodule from the `data` section of the composed config.
data_section_cfg = cfg.data
data_module = hydra.utils.instantiate(data_section_cfg)

In [None]:
# Prepare the datamodule for the "fit" stage and fetch its training loader.
data_module.setup(stage="fit")
train_dloader = data_module.train_dataloader()

In [None]:
# Take one batch from the training loader for inspection; drop the iterator
# right away so it does not linger in the notebook namespace.
_train_batches = iter(train_dloader)
batch = next(_train_batches)
del _train_batches

### Load the config and checkpoint of a finished run

Inside a Lightning checkpoint you’ll find:

* 16-bit scaling factor (if using 16-bit precision training)
* Current epoch
* Global step
* LightningModule’s state_dict
* State of all optimizers
* State of all learning rate schedulers
* State of all callbacks (for stateful callbacks)
* State of datamodule (for stateful datamodules)
* The hyperparameters (init arguments) with which the model was created
* The hyperparameters (init arguments) with which the datamodule was created 
* State of Loops

In [9]:
from src.models.magic_pen.task import MagicPenModule

In [10]:
# Registry of finished runs: model_type -> model_name -> run record.
# Each record names the run's output directory once, so the checkpoint path
# and the config path below can never point at two different runs.
#   run_dir  : run output directory name
#   location : trailing arguments forwarded to make_path
#   ckpt     : checkpoint file, relative to run_dir
_RUNS = {
    "probing_concat": {
        "baseline": {
            "run_dir": "2024-08-02_18-31-45",
            "location": (LOGS_PATH, "beta", "beta_probing/levir-cd/vit-b"),
            "ckpt": "checkpoints/epoch_099.ckpt",
        },
    },
}


def _run_file(run, relative_path):
    """Return make_path(...) for `relative_path` inside the given run's directory."""
    return make_path(f"{run['run_dir']}/{relative_path}", *run["location"])


# Lightning checkpoint of each registered run.
_register_runs_ckpt = {
    kind: {name: _run_file(run, run["ckpt"]) for name, run in runs.items()}
    for kind, runs in _RUNS.items()
}

# Hydra config saved alongside each registered run.
_register_runs_params = {
    kind: {name: _run_file(run, ".hydra/config.yaml") for name, run in runs.items()}
    for kind, runs in _RUNS.items()
}

In [11]:
# Which registered run to inspect: (model type, run name).
model_type, model_name = "probing_concat", "baseline"

In [12]:
import yaml


def load_yaml(path):
    """Read a YAML file and return its parsed contents.

    Uses ``yaml.safe_load``, so only standard YAML types are constructed
    (no arbitrary Python objects).

    Args:
        path: Path (str or path-like) of the YAML file to read.

    Returns:
        The parsed document (e.g. a dict for a mapping-rooted file, ``None``
        for an empty file).
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default encoding.
    with open(path, "r", encoding="utf-8") as fp:
        return yaml.safe_load(fp)
# Config that was saved alongside the selected run.
run_config_path = _register_runs_params[model_type][model_name]
config = load_yaml(run_config_path)

In [13]:
# Display the saved run config. NOTE(review): compare it with `cfg` composed
# above. Here prompt_encoder.embed_dim is 512, while the freshly composed cfg
# has 256, so the two configs do not describe the same architecture.
config

{'data': {'name': 'levir-cd',
  '_target_': 'src.data.datamodule.CDDataModule',
  'params': {'prompt_type': 'sample',
   'n_prompt': 1,
   'loc': 'center',
   'batch_size': 2,
   'num_worker': 4,
   'pin_memory': False,
   'n_shape': 3}},
 'model': {'network': {'image_encoder': {'_target_': 'src.models.segment_anything.modeling.image_encoder_dev.ImageEncoderViT',
    'depth': 12,
    'embed_dim': 768,
    'img_size': 1024,
    'mlp_ratio': 4,
    'norm_layer': None,
    'num_heads': 12,
    'patch_size': 16,
    'qkv_bias': True,
    'use_rel_pos': True,
    'global_attn_indexes': [2, 5, 8, 11],
    'window_size': 14,
    'out_chans': 256},
   'prompt_encoder': {'_target_': 'src.models.segment_anything.modeling.prompt_encoder_dev.PromptEncoder',
    'embed_dim': 512,
    'image_embedding_size': [64, 64],
    'input_image_size': [1024, 1024],
    'mask_in_chans': 16},
   'mask_decoder': {'transformer': {'_target_': 'src.models.segment_anything.modeling.transformer_dev.TwoWayTransformer'

In [14]:
# Work on a deep copy of the run's model params. Editing them later (e.g.
# dropping `sam_ckpt_path`) then does not silently rewrite the loaded
# `config`, and re-running this cell always starts from the original values.
params = deepcopy(config["model"]["instance"]["params"])

In [15]:
# Remove the training-time SAM checkpoint path from the params. `pop` with a
# default keeps the cell idempotent (a plain `del` raises KeyError on re-run);
# the trailing `;` stops Jupyter from echoing the removed path.
# NOTE(review): MagicPenModule.load_from_checkpoint below fails with
# "Please provide sam checkpoint" once this key is gone. Confirm whether it
# should be overridden with a valid local path instead of removed.
params.pop("sam_ckpt_path", None);

In [27]:
# Materialize the network's parameter tensors (parameters, not layers) into a list for inspection.
layers = [param for param in model.parameters()]

In [33]:
# Shape of the parameter tensor at index 14 (same torch.Size as `.shape`).
layers[14].size()

torch.Size([3072])

In [55]:
# Restore the trained module from the selected run's Lightning checkpoint.
# The extra keyword arguments are forwarded to the module's constructor.
# NOTE(review): `model` was built from the freshly composed `cfg`
# (experiment=probing_diff, prompt_encoder.embed_dim=256). The saved run
# config has prompt_encoder.embed_dim=512, so the checkpoint's weights may
# not fit this network. Consider building the network from the run's `config`.
# NOTE(review): the ValueError below is likely caused by `sam_ckpt_path`
# having been removed from `params`.
run_ckpt_path = _register_runs_ckpt[model_type][model_name]
module = MagicPenModule.load_from_checkpoint(
    run_ckpt_path,
    params=params,
    network=model,
)

ValueError: Please provide sam checkpoint

### Run inference