In [None]:
import torch, detectron2
!nvcc --version
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("detectron2:", detectron2.__version__)

In [None]:
# Imports and global setup for the X-Decoder step-by-step walkthrough.
import os
import sys
import logging
import argparse
import warnings
from copy import deepcopy
from pprint import pprint

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torch.nn import functional as F

from detectron2.data import MetadataCatalog
from detectron2.utils.colormap import random_color
from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode

# Make the repository root (parent of the kernel's working directory) importable.
# It goes to the FRONT of sys.path: the project packages `utils` and `datasets` share their
# names with common pip packages (e.g. Hugging Face `datasets`), and appending the root would
# let an installed package shadow them. Removing an existing entry first keeps re-runs from
# stacking duplicates and guarantees the root ends up first.
home_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
pth = home_dir
if home_dir in sys.path:
    sys.path.remove(home_dir)
sys.path.insert(0, home_dir)
print(home_dir)

# Seed numpy's global RNG before the project modules are imported and used.
np.random.seed(1)

# Silence library warnings, including any emitted while importing the project modules.
warnings.filterwarnings(action='ignore')
logger = logging.getLogger(__name__)

# Project-local modules (resolved through home_dir above).
from utils.arguments import load_opt_command, load_opt_from_config_files, load_config_dict_to_opt
from utils.misc import hook_metadata, hook_switcher, hook_opt
from xdecoder.BaseModel import BaseModel
from xdecoder import build_model
from utils.visualizer import Visualizer
from utils.distributed import init_distributed
from datasets import build_evaluator, build_eval_dataloader
from xdecoder.utils import get_class_names

In [None]:
# Build the option dict the X-Decoder code expects. This mimics the repo's command-line entry
# point, but parses an empty argument list instead of sys.argv (which belongs to the kernel).
parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.')
parser.add_argument('--command', default="evaluate", help='Command: train/evaluate/train-and-evaluate')
parser.add_argument('--conf_files', nargs='+', help='Path(s) to the config file(s).')
parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.')
parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER)

cmdline_args = parser.parse_args('')
cmdline_args.conf_files = [os.path.join(home_dir, "configs/xdecoder/svlp_focalt_lang.yaml")]
# Anchor the checkpoint to the repository root like the config file above, instead of
# depending on the kernel's working directory (same file as '../checkpoints/...').
cmdline_args.overrides = ['WEIGHT', os.path.join(home_dir, 'checkpoints/xdecoder_focalt_best_openseg.pt')]

opt = load_opt_from_config_files(cmdline_args.conf_files)

# Overrides form a flat [KEY, VALUE, KEY, VALUE, ...] list; an unpaired trailing key is ignored.
n_pairs = len(cmdline_args.overrides) // 2
keys = cmdline_args.overrides[0:2 * n_pairs:2]
vals = cmdline_args.overrides[1:2 * n_pairs:2]
# Each value is cast with the type of the existing config entry. For booleans only bool('')
# is False, so every spelling of "false" is mapped to ''. The former length-5 test missed
# padded values such as ' false' (which became ' ', i.e. truthy) and 'FALSE'.
vals = ['' if val.strip().lower() == 'false' else val for val in vals]

# Type of the current value at each dotted key path (e.g. "MODEL.DECODER.X").
# A missing key raises KeyError.
types = []
for key in keys:
    node = opt
    for part in key.split('.'):
        node = node[part]
    types.append(type(node))

config_dict = {k: cast(v) for k, v, cast in zip(keys, vals, types)}
load_config_dict_to_opt(opt, config_dict)

# Also copy every explicitly set command-line field (command, conf_files, overrides, ...) into opt.
for key, val in vars(cmdline_args).items():
    if val is not None:
        opt[key] = val
opt = init_distributed(opt)

In [None]:
# Input/output locations for the walkthrough.
# The checkpoint path comes from the WEIGHT entry set through the config overrides;
# a single-argument os.path.join() was a no-op, so it is used directly.
pretrained_pth = opt['WEIGHT']
output_root = './output'  # relative to the kernel's working directory
# Resolve the sample image from the repository root rather than the working directory.
image_pth = os.path.join(home_dir, 'images/animals.png')
print(pretrained_pth)

## Model

In [None]:
# Build the X-Decoder network from the options, wrap it in BaseModel, load the checkpoint,
# then switch to eval mode and move it to the GPU.
model = (
    BaseModel(opt, build_model(opt))
    .from_pretrained(pretrained_pth)
    .eval()
    .cuda()
)

In [None]:
# Preprocessing: resize the image (size 800) with bicubic interpolation.
# InterpolationMode is the enum torchvision expects; its PIL integer form is deprecated there.
t = [transforms.Resize(800, interpolation=transforms.InterpolationMode.BICUBIC)]
transform = transforms.Compose(t)

# Label set for this demo; the names are turned into text embeddings for the model below.
thing_classes = ['zebra','antelope','giraffe','ostrich','sky','water','grass','sand','tree']
# One random RGB colour (list of ints) per class; a new set is drawn on every run.
thing_colors = [random_color(rgb=True, maximum=255).astype(np.int64).tolist() for _ in range(len(thing_classes))]
thing_dataset_id_to_contiguous_id = {x: x for x in range(len(thing_classes))}

# detectron2's Metadata refuses to overwrite an attribute with a different value. Because the
# colours above change on every run, re-executing this cell used to raise an AssertionError.
# Dropping any previous "demo" entry first makes the cell idempotent.
if "demo" in MetadataCatalog.list():
    MetadataCatalog.remove("demo")
MetadataCatalog.get("demo").set(
    thing_colors=thing_colors,
    thing_classes=thing_classes,
    thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id,
)

# Compute text embeddings for the class names plus an extra "background" entry.
model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + ["background"], is_eval=False)
metadata = MetadataCatalog.get('demo')
model.model.metadata = metadata
model.model.sem_seg_head.num_classes = len(thing_classes)
# Trailing ';' suppresses the (very long) module repr that model.eval() would otherwise display.
model.eval();

In [None]:
# Display the pixel-decoder module's structure (the stage the forward-pass cell unrolls by hand).
model.model.sem_seg_head.pixel_decoder

In [None]:
# FocalNet backbone hyper-parameters from the loaded config, shown for reference.
# Note: despite its name, `num_layers` holds the config's DEPTHS entry.
num_layers, out_indices, focal_levels = (
    opt['MODEL']['BACKBONE']['FOCAL'][field]
    for field in ('DEPTHS', 'OUT_INDICES', 'FOCAL_LEVELS')
)
num_layers, out_indices, focal_levels

In [None]:
# Display the decoder section of the loaded config for reference.
opt['MODEL']['DECODER']

## Backbone

In [None]:
# Manual, step-by-step forward pass of the segmentation model:
#   image -> backbone -> pixel decoder -> predictor,
# printing intermediate tensor shapes. The one-call equivalents are kept as comments.
pixel_mean = torch.Tensor([123.675, 116.280, 103.530]).view(-1, 1, 1).cuda()
pixel_std = torch.Tensor([58.395, 57.120, 57.375]).view(-1, 1, 1).cuda()

# The submodules are used in place: their forward calls do not modify them, so deep-copying
# them (as before) only duplicated their weights on the GPU.
backbone = model.model.backbone
pixel_decoder = model.model.sem_seg_head.pixel_decoder

with torch.no_grad():
    # The context manager closes the file handle once the RGB copy has been made.
    with Image.open(image_pth) as pil_image:
        image_ori = pil_image.convert('RGB')
    width, height = image_ori.size
    image = np.asarray(transform(image_ori))
    image_ori = np.asarray(image_ori)
    # HWC array -> CHW tensor on the GPU.
    images = torch.from_numpy(image.copy()).permute(2, 0, 1).cuda()
    batch_inputs = [{'image': images, 'height': height, 'width': width}]
    # outputs = model.forward(batch_inputs)

    # Normalise, then batch and pad with size_divisibility=32.
    images = [x["image"].to("cuda") for x in batch_inputs]
    images = [(x - pixel_mean) / pixel_std for x in images]
    images = ImageList.from_tensors(images, 32)
    print(f"Image shape: {images.tensor.shape}")

    ######################################################
    ## Backbone
    # output = model.model.backbone(images.tensor)
    # print(output.shape)
    ######################################################
    ## Backbone inner code

    # 1. Patch embedding, then flatten the spatial grid: (N, C, Wh, Ww) -> (N, Wh*Ww, C).
    x = backbone.patch_embed(images.tensor)
    print(f"patch_embed shape: {x.shape}")
    Wh, Ww = x.size(2), x.size(3)
    x = x.flatten(2).transpose(1, 2)
    print(f"flatten shape: {x.shape}")

    # 2. Dropout on the patch tokens (presumably a no-op in eval mode).
    x = backbone.pos_drop(x)

    # 3. Stages. Iterate over the stages the backbone actually contains and its own
    #    out_indices, so this cell does not depend on variables from another cell.
    outs = {}
    for i, layer in enumerate(backbone.layers):
        x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
        if i in backbone.out_indices:
            norm_layer = getattr(backbone, f'norm{i}')
            x_out = norm_layer(x_out)
            # Tokens back to a feature map: (N, H*W, C) -> (N, C, H, W).
            out = x_out.view(-1, H, W, backbone.num_features[i]).permute(0, 3, 1, 2).contiguous()
            outs["res{}".format(i + 2)] = out

    if len(backbone.out_indices) == 0:
        outs["res5"] = x_out.view(-1, H, W, backbone.num_features[i]).permute(0, 3, 1, 2).contiguous()

    print("\nBackbone output")
    outputs = {}
    for k in outs.keys():
        if k in backbone._out_features:
            outputs[k] = outs[k]
            print(f"{k}: {outs[k].shape}")

    ##############################################################################################################################
    ###### XDecoder Head ######
    ###### 1. Pixel Decoder ######
    # mask_features, transformer_encoder_features, multi_scale_features = model.model.sem_seg_head.pixel_decoder.forward_features(outputs)
    # print(mask_features.shape, transformer_encoder_features.shape, len(multi_scale_features))
    ##############################################################################################################################
    ## Pixel decoder inner code

    multi_scale_features = []
    num_cur_levels = 0
    # Reset so a stale value from a previous run of this cell can never be reported.
    transformer_encoder_features = None

    # Reverse the feature maps into top-down order (from low to high resolution).
    # A level without a lateral conv goes through the transformer encoder. Every other level is
    # an FPN merge of its lateral projection with the upsampled previous output `y`, so the
    # encoder level must come first.
    for idx, f in enumerate(pixel_decoder.in_features[::-1]):
        x = outputs[f]
        lateral_conv = pixel_decoder.lateral_convs[idx]
        output_conv = pixel_decoder.output_convs[idx]
        if lateral_conv is None:
            transformer = pixel_decoder.input_proj(x)
            pos = pixel_decoder.pe_layer(x)
            transformer = pixel_decoder.transformer(transformer, None, pos)
            y = output_conv(transformer)
            # Save the intermediate feature as input to the transformer decoder.
            transformer_encoder_features = transformer
        else:
            cur_fpn = lateral_conv(x)
            # Following the FPN implementation, nearest upsampling is used here.
            y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
            y = output_conv(y)
        if num_cur_levels < pixel_decoder.maskformer_num_feature_levels:
            multi_scale_features.append(y)
            num_cur_levels += 1

    mask_features = pixel_decoder.mask_features(y) if pixel_decoder.mask_on else None

    print("\nPixel Decoder output")
    # Either value can be None (mask_on disabled / no encoder level); print None instead of crashing.
    print(f"mask features: {mask_features.shape if mask_features is not None else None}")
    print(f"transformer encoder features: {transformer_encoder_features.shape if transformer_encoder_features is not None else None}")
    for multi_scale_feature in multi_scale_features:
        print(f"multi_scale_feature: {multi_scale_feature.shape}")

    ##############################################################################################################################
    ###### XDecoder Head ######
    ###### 2. Predictor ######
    predictions = model.model.sem_seg_head.predictor(multi_scale_features, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={})
    print("\npredictor output")
    print(predictions.keys())
    ##############################################################################################################################
    ##############################################################################################################################

In [None]:
# Shapes of the multi-scale pixel-decoder features passed to the predictor, one per line.
for level_idx in range(len(multi_scale_features)):
    print(multi_scale_features[level_idx].shape)

In [None]:
# Disabled reference: the same preprocessing followed by the model's own end-to-end
# forward pass, which the manual step-by-step cell above unrolls stage by stage.
# with torch.no_grad():
#     image_ori = Image.open(image_pth).convert('RGB')
#     width = image_ori.size[0]
#     height = image_ori.size[1]
#     image = transform(image_ori)
#     image = np.asarray(image)
#     image_ori = np.asarray(image_ori)
#     images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
#     batch_inputs = [{'image': images, 'height': height, 'width': width}]
#     outputs = model.forward(batch_inputs)

In [None]:
# model.model.sem_seg_head.state_dict().keys()

In [None]:
# model.model.sem_seg_head.pixel_decoder

In [None]:
# model.model.sem_seg_head.predictor

In [None]:
# Disabled reference: handles to the backbone, pixel decoder and predictor so the stages can
# be called one after another. The backbone and pixel decoder are deep copies, while the
# predictor is the live module.
# from copy import deepcopy

# # from xdecoder.backbone.focal import D2FocalNet
# focal_backbone = deepcopy(model.model.backbone)

# # from xdecoder.body.encoder.transformer_encoder_fpn import TransformerEncoderPixelDecoder
# pixel_decoder = deepcopy(model.model.sem_seg_head.pixel_decoder)

# # from xdecoder.body.decoder import TransformerEncoderPixelDecoder
# predictor = model.model.sem_seg_head.predictor


In [None]:
# model.model.sem_seg_head.predictor

In [None]:
# model.model.sem_seg_head.transformer_in_feature

In [None]:
# Disabled reference: run the stages as whole modules (focal_backbone -> pixel_decoder ->
# predictor, the handles from the disabled cell that creates them) and print output shapes.
# pixel_mean = torch.Tensor([123.675, 116.280, 103.530]).view(-1, 1, 1).cuda()
# pixel_std = torch.Tensor([58.395, 57.120, 57.375]).view(-1, 1, 1).cuda()

# # focal_backbone.eval()
# # evaluate() function in xdecoder_model.py
# with torch.no_grad():
#     image_ori = Image.open(image_pth).convert('RGB')
#     width = image_ori.size[0]
#     height = image_ori.size[1]
#     image = transform(image_ori)
#     image = np.asarray(image)
#     image_ori = np.asarray(image_ori)
#     images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()
#     batch_inputs = [{'image': images, 'height': height, 'width': width}]
#     images = [x["image"].to('cuda') for x in batch_inputs]
#     images = [(x - pixel_mean) / pixel_std for x in images]
    
#     images = ImageList.from_tensors(images, 32)
#     img_bs = images.tensor.shape[0]

#     # Process
#     # 1. image -> focal_backbone -> features
#     # 2. features -> TransformerEncoderPixelDecoder -> mask_features, transformer_encoder_features, multi_scale_features

#     features = focal_backbone(images.tensor)
#     print(features['res5'].shape)
#     mask_features, transformer_encoder_features, multi_scale_features = pixel_decoder(features)
    
#     print(f"mask_features {mask_features.shape}")
#     print(f"transformer_encoder_features {transformer_encoder_features.shape}")
#     print(f"multi_scale_features {multi_scale_features[0].shape}")

#     # Modifying the predictor is probably the way to go.
#     predictions = predictor(multi_scale_features, mask_features, None, None, None, "seg", {})
#     for key, value in predictions.items():
#         if not isinstance(predictions[key], list):
#             print(key, predictions[key].shape)


In [None]:
# Disabled reference: the original end-to-end demo. It registers the class metadata, runs
# model.forward, draws the predicted instance masks and saves the result to <output_root>/inst.png.
# t = []
# t.append(transforms.Resize(800, interpolation=Image.BICUBIC))
# transform = transforms.Compose(t)

# thing_classes = ['zebra','antelope','giraffe','ostrich','sky','water','grass','sand','tree']
# thing_colors = [random_color(rgb=True, maximum=255).astype(np.int64).tolist() for _ in range(len(thing_classes))]
# thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))}

# MetadataCatalog.get("demo").set(
#     thing_colors=thing_colors,
#     thing_classes=thing_classes,
#     thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id,
# )
# model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + ["background"], is_eval=False)
# metadata = MetadataCatalog.get('demo')
# model.model.metadata = metadata
# model.model.sem_seg_head.num_classes = len(thing_classes)

# with torch.no_grad():
#     image_ori = Image.open(image_pth).convert('RGB')
#     width = image_ori.size[0]
#     height = image_ori.size[1]
#     image = transform(image_ori)
#     image = np.asarray(image)
#     image_ori = np.asarray(image_ori)
#     images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()

#     batch_inputs = [{'image': images, 'height': height, 'width': width}]
#     outputs = model.forward(batch_inputs)
#     visual = Visualizer(image_ori, metadata=metadata)

#     inst_seg = outputs[-1]['instances']
#     inst_seg.pred_masks = inst_seg.pred_masks.cpu()
#     inst_seg.pred_boxes = BitMasks(inst_seg.pred_masks > 0).get_bounding_boxes()
#     demo = visual.draw_instance_predictions(inst_seg) # rgb Image

#     if not os.path.exists(output_root):
#         os.makedirs(output_root)
#     demo.save(os.path.join(output_root, 'inst.png'))