In [None]:
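'''
COTR demo for human faces.
We use an off-the-shelf face landmarks detector: https://github.com/1adrianb/face-alignment

This notebook loads a pretrained COTR model, prints a summary of its backbone
and plots the backbone's convolution kernels.
'''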
import argparse
import os
import time

import cv2
import numpy as np
import torch
import imageio
import matplotlib.pyplot as plt
# import torchprof

from COTR.utils import utils, debug_utils
from COTR.utils.stopwatch import StopWatch

from COTR.models import build_model
from COTR.options.options import *
from COTR.options.options_utils import *
from COTR.inference.inference_helper import triangulate_corr
from COTR.inference.sparse_engine import SparseEngine

from pytorch_memlab import MemReporter
from torchinfo import summary

# Call the project helper with seed 0 (presumably seeds the RNGs in use; confirm
# in COTR.utils), then disable autograd globally for the rest of the session.
utils.fix_randomness(0)
torch.set_grad_enabled(False)


def main(opt):
    """Build the COTR model described by ``opt`` and load its checkpoint.

    The checkpoint at ``opt.load_weights_path`` is deserialized onto the CPU and
    its ``'model_state_dict'`` entry is handed to ``utils.safe_load_weights``.

    Args:
        opt: parsed options; must carry ``load_weights_path`` plus whatever
            ``build_model`` reads.

    Returns:
        The model, moved to CUDA and switched to eval mode.
    """
    net = build_model(opt).cuda()

    checkpoint = torch.load(opt.load_weights_path, map_location='cpu')
    utils.safe_load_weights(net, checkpoint['model_state_dict'])

    # eval() switches the network to inference mode
    return net.eval()
    


# Build the option parser: the arguments registered by set_COTR_arguments plus
# the two options added here (--out_dir and --load_weights).
parser = argparse.ArgumentParser()
set_COTR_arguments(parser)
parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')

# The argument list is hard-coded here and passed to parse_args() explicitly.
args = ['--load_weights=default'] # model id; becomes the sub-directory of out_dir holding the checkpoint

opt = parser.parse_args(args)
opt.command = ' '.join(args)

# The feed-forward width is looked up from opt.layer (not defined in this cell;
# presumably registered by set_COTR_arguments).
layer_2_channels = {'layer1': 256,
                    'layer2': 512,
                    'layer3': 1024,
                    'layer4': 2048, }
opt.dim_feedforward = layer_2_channels[opt.layer]
# NOTE(review): main() reads opt.load_weights_path unconditionally, but it is only
# assigned here when --load_weights is given; with the default (None) main() would
# fail unless the attribute is defined elsewhere.
if opt.load_weights:
    opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
print_opt(opt)
model = main(opt)


In [None]:
# Take the `body` of the first backbone entry and print a torchinfo summary of it
# for an input of size [1, 3, 256, 256] (presumably batch, channels, height, width).
backbone = model.backbone[0].body
summary(model=backbone, input_size=[1, 3, 256, 256])

In [None]:
# Print the name of every backbone state-dict entry whose key contains 'conv'.
# (Alternative: iterate over backbone.named_parameters() and print each name.)
for key in backbone.state_dict():
    if 'conv' not in key:
        continue
    print(key)

In [None]:
# Only the last bare expression of a cell is echoed, so the shape is printed
# explicitly; as a bare expression its value was silently discarded.
print(backbone.conv1.weight.shape)
backbone.state_dict()['layer1.0.bn1.weight']

In [None]:
import matplotlib as mpl
import matplotlib.colors as colors
from matplotlib.cm import ScalarMappable

# Default tensor to inspect: the weights of the backbone's `conv1` layer.
param = backbone.conv1.weight

def nop(x):
    """Identity transform: hand ``x`` back untouched (default ``pre_param`` of ``plot_img``)."""
    return x

def make_colmap(vmin, vmax):
    """Build three single-hue colormaps (red, green, blue) that are black in the middle.

    Each map has 255 entries: the channel intensity ramps linearly from 1.0 down
    to 0.0 over the first 128 entries and mirrors back up to 1.0 over the
    remaining 127; the other two channels are always 0.

    Args:
        vmin: lower bound of the data range.
        vmax: upper bound of the data range.
            Both are only used to compute and print ``mm = max(|vmin|, |vmax|)``;
            the returned colormaps do not depend on them.

    Returns:
        ``[cmap_r, cmap_g, cmap_b]``: ``ListedColormap`` objects named "RED",
        "GREEN" and "BLUE".
    """
    mm = max(abs(vmin), abs(vmax))
    print(f"mm:{mm}")
    # NOTE(review): for the black midpoint to coincide with the value 0, the caller
    # has to plot with a symmetric range (vmin=-mm, vmax=mm); mm is not returned.

    # Intensity ramp 1 -> 0 -> 1 as a column vector (the midpoint is not duplicated).
    ramp_down = np.linspace(1.0, 0.0, 128)
    ramp = np.hstack((ramp_down, np.flip(ramp_down[:-1]))).reshape([-1, 1])
    zero = np.zeros_like(ramp)

    # One RGB table per hue: the ramp sits in that hue's column, zeros elsewhere.
    cmap_r = colors.ListedColormap(np.hstack((ramp, zero, zero)), "RED")
    cmap_g = colors.ListedColormap(np.hstack((zero, ramp, zero)), "GREEN")
    cmap_b = colors.ListedColormap(np.hstack((zero, zero, ramp)), "BLUE")

    return [cmap_r, cmap_g, cmap_b]
    

def plot_img(name, param, pre_param=nop):
    """Plot the kernels of a weight tensor as a grid of small images.

    One figure is produced per input channel, for at most the first three
    channels (drawn with the "Reds", "Greens" and "Blues" colormaps). Every
    figure shows one image per unit, 16 per row, titled with that kernel's mean,
    plus a horizontal colorbar. All images share the colour range given by the
    min/max of the whole (transformed) tensor.

    Args:
        name: label used in the printed range line and in the figure title.
        param: weight tensor indexed as ``param[unit, ch, :, :]``
            (presumably (out_channels, in_channels, kH, kW); confirm with caller).
        pre_param: transform applied to the tensor before taking min/max and to
            each kernel before plotting; defaults to the identity.

    Returns:
        None. Figures are shown with ``plt.show()`` and then closed.
    """
    # Global colour range, computed from a single pass through pre_param.
    transformed = pre_param(param)
    pmax = transformed.max().item()
    pmin = transformed.min().item()
    # OPTION: COLMAPS = make_colmap(pmin, pmax)
    COLMAPS = ["Reds", "Greens", "Blues"]
    W = 16
    n_units = param.shape[0]
    # Ceiling division so a partially filled last row still gets grid cells;
    # at least one row so GridSpec never receives 0 rows.
    H = max(1, -(-n_units // W))
    gs = mpl.gridspec.GridSpec(H, W)
    # Never index past the available input channels (or colormaps).
    n_channels = min(len(COLMAPS), param.shape[1])
    for ch in range(n_channels):
        print(f'ch={ch}')
        print(f"{name} pmin:{pmin}, pmax:{pmax}")
        fig = plt.figure(figsize=(W, H * 1.1))
        for unit in range(n_units):
            kernel = pre_param(param[unit, ch, :, :])
            mean = kernel.mean().item()
            img = kernel.detach().cpu().numpy()

            ax = fig.add_subplot(gs[unit])
            ax.imshow(img, cmap=COLMAPS[ch], vmin=pmin, vmax=pmax)
            ax.set_title(f"{mean:.1e}")
            ax.axis('off')
        fig.suptitle(f"{name} ch={ch}")

        # Invisible host axes ([x, y, w, h] in figure coordinates) that the
        # colorbar is attached to.
        cbar_host = fig.add_axes([0.2, 0.0, 0.6, 0.3])
        cbar_host.set_axis_off()

        # Stand-alone mappable carrying the same norm/cmap as the images.
        norm = colors.Normalize(vmin=pmin, vmax=pmax)
        mappable = ScalarMappable(cmap=COLMAPS[ch], norm=norm)
        mappable.set_array([])

        cb = fig.colorbar(mappable, ax=cbar_host, aspect=90, pad=0.08,
                          shrink=0.9, orientation='horizontal')
        cb.ax.tick_params(labelsize=16)

        plt.show()
        # Close this specific figure so figures do not pile up across layers.
        plt.close(fig)
        
# Plot every convolution of the backbone whose kernel is larger than 1x1.
# The state dict is built once instead of once per key.
backbone_state = backbone.state_dict()
for key in backbone_state:
    if 'conv' not in key:
        continue
    print(key)
    param = backbone_state[key]
    # Only 4-D tensors can be indexed as kernels; anything else (e.g. a bias)
    # and 1x1 kernels are skipped.
    if param.dim() == 4 and param.shape[2] > 1:
        plot_img(key, param)