In [1]:
import hydra
from src.commons.constants import PROJECT_PATH
from omegaconf import DictConfig, OmegaConf

import numpy as np
import cv2
from PIL import Image
import torch
import pytorch_lightning as pl
import pandas as pd
import os
from copy import deepcopy
import torch.nn.functional as F
from torchmetrics import Metric
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from torch.utils import data
from torch.nn.utils.rnn import pad_sequence

from src.commons.utils_io import load_sam
from src.commons.utils import to_numpy, SegAnyChangeVersion, show_img, show_pair_img, show_prediction_sample, resize
from src.models.commons.mask_process import extract_object_from_batch, binarize_mask
from src.commons.constants import IMG_SIZE
from src.data.process import generate_grid_prompt
from src.commons.utils import create_sample_grid_with_prompt, get_mask_with_prompt, fig2arr

In [4]:
def load_config(config_name: str = "train", overrides=None, config_path: str = "../../configs"):
    """Compose the Hydra training configuration used by this notebook.

    Hydra is initialized only for the duration of a ``with`` block, so the
    function can be called repeatedly in the same kernel. It does not rely on a
    manual ``GlobalHydra.instance().clear()`` before each call.

    Args:
        config_name: Name of the primary config to compose (default ``"train"``).
        overrides: Iterable of Hydra override strings. ``None`` selects this
            notebook's default run: ``experiment=probing_diff``,
            ``sam_type=small``, ``data=levir-cd``.
        config_path: Config directory handed to ``hydra.initialize``
            (default ``"../../configs"``). It is resolved by Hydra, not by
            this function.

    Returns:
        The configuration object returned by ``hydra.compose``.
    """
    if overrides is None:
        overrides = ["experiment=probing_diff", "sam_type=small", "data=levir-cd"]

    # As a context manager, `hydra.initialize` restores the previous global
    # Hydra state on exit. Without this, a second call in the same kernel would
    # find Hydra already initialized.
    with hydra.initialize(config_path=config_path, version_base=None):
        cfg = hydra.compose(config_name=config_name, overrides=list(overrides))

    return cfg

In [5]:
# Compose the experiment config and show it as YAML.
# Reset Hydra's global state (accessed via `GlobalHydra.instance()`) before
# composing. This appears intended to let the cell be re-executed in the same
# kernel. NOTE(review): a bare `hydra.initialize` is expected to fail if Hydra is
# already initialized; confirm against the installed Hydra version.
from hydra.core.global_hydra import GlobalHydra
GlobalHydra.instance().clear()
cfg = load_config()
# Render the composed config as YAML. The printed output below is what later
# cells read from `cfg` (e.g. `cfg.model.instance`).
print(OmegaConf.to_yaml(cfg))

data:
  name: levir-cd
  _target_: src.data.datamodule.CDDataModule
  params:
    prompt_type: sample
    n_prompt: 1
    loc: center
    batch_size: 2
    num_worker: 4
    pin_memory: false
    n_shape: 3
model:
  network:
    image_encoder:
      _target_: src.models.segment_anything.modeling.image_encoder_dev.ImageEncoderViT
      depth: 12
      embed_dim: 768
      img_size: 1024
      mlp_ratio: 4
      norm_layer: null
      num_heads: 12
      patch_size: 16
      qkv_bias: true
      use_rel_pos: true
      global_attn_indexes:
      - 2
      - 5
      - 8
      - 11
      window_size: 14
      out_chans: 256
    prompt_encoder:
      _target_: src.models.segment_anything.modeling.prompt_encoder_dev.PromptEncoder
      embed_dim: 256
      image_embedding_size:
      - 64
      - 64
      input_image_size:
      - 1024
      - 1024
      mask_in_chans: 16
    mask_decoder:
      transformer:
        _target_: src.models.segment_anything.modeling.transformer_dev.TwoWayTransfo

In [6]:
# Build the object described by `cfg.model.instance` via Hydra instantiation.
# The log line "Weights loaded for : ['image_encoder']" shown below is emitted
# during this call. Later cells access `module.model.prompt_encoder`.
# NOTE(review): `cfg.model.instance` is cut off in the printed config above, so its
# `_target_` class and checkpoint source are not visible here. Presumably it is a
# LightningModule wrapping the SAM-style network; confirm.
module = hydra.utils.instantiate(cfg.model.instance)

2024-08-02 12:08:03,241 - INFO ::  Weights loaded for : ['image_encoder']


In [12]:
# Disable gradients for entry 2 of the prompt encoder's `point_embeddings`.
# `requires_grad_` modifies the module in place and returns it, which is why the
# cell displays `Embedding(1, 256)`.
# NOTE(review): index 2 is a magic number whose meaning is not visible here. In
# upstream SAM's PromptEncoder, point_embeddings[2] is the box top-left-corner
# embedding, but this project uses a `prompt_encoder_dev` variant. Confirm the
# index, and why only this entry is frozen, before relying on it.
module.model.prompt_encoder.point_embeddings[2].requires_grad_(False)

Embedding(1, 256)