In [1]:
from feeder.my_feeder import Feeder
import numpy as np
from gcn import GCN
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import glob
from tqdm.notebook import tqdm

In [2]:
# Notebook configuration: input locations, graph-sampling settings, checkpoint and device.
working_dir = os.getcwd()

f_dim = 512

# All four test-split inputs live in the same directory, so build that directory once
# instead of repeating the literal in every path.
feat_dir = os.path.join(working_dir, 'features/l2_512_k8_1_13300.0943/test')
feat_path = os.path.join(feat_dir, 'feat.json')
knn_graph_path = os.path.join(feat_dir, 'knn_graph.json')
label_path = os.path.join(feat_dir, 'label.json')
obj_type_path = os.path.join(feat_dir, 'obj_type.json')
# Passed straight to Feeder below.
# NOTE(review): presumably the number of neighbours sampled at hop 1 and hop 2 — confirm in Feeder.
k_at_hop = [8, 5]
active_connection = 5
seed = 0

# Pick the checkpoint whose trailing "_<number>.ckpt" is largest.
ckpt_candidates = glob.glob(os.path.join(working_dir, f'logs/l2_{f_dim}_k8/*.ckpt'))
if not ckpt_candidates:
    # Without this guard an empty directory surfaces as a bare IndexError from `[-1]`.
    raise FileNotFoundError(
        f"no .ckpt files found under {os.path.join(working_dir, f'logs/l2_{f_dim}_k8')}"
    )
ckpt_weight_path = sorted(ckpt_candidates, key=lambda x: int(x.split('_')[-1].replace('.ckpt', '')))[-1]
print(ckpt_weight_path)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

/home/jayson/gcn_clustering/logs/l2_512_k8/epoch_47.ckpt


In [3]:
# Build the test-split dataset from the paths and sampling settings configured above.
dataset = Feeder(feat_path,
                 knn_graph_path,
                 label_path,
                 obj_type_path,
                 seed,
                 k_at_hop,
                 active_connection,
                 train=False)

net = GCN(in_dim=f_dim)
# map_location remaps the stored tensors onto the device chosen above; without it a
# checkpoint saved from CUDA tensors cannot be loaded when `device` falls back to 'cpu'.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
weight = torch.load(ckpt_weight_path, map_location=device)['state_dict']
net.load_state_dict(weight)
net.to(device)
# eval() switches the network to inference mode; as the cell's last expression it also
# displays the module structure.
net.eval()

gcn(
  (bn0): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (conv1): GraphConv(
    (agg): MeanAggregator()
  )
  (conv2): GraphConv(
    (agg): MeanAggregator()
  )
  (conv3): GraphConv(
    (agg): MeanAggregator()
  )
  (conv4): GraphConv(
    (agg): MeanAggregator()
  )
  (classifier): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): PReLU(num_parameters=128)
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
)

In [4]:
# Pick a single test sample to inspect by hand before running the full evaluation loop.
i = 150
# `node_list` and `h1id` bound here are also read as globals by
# extract_clusters_from_matrix when it is called further down.
feat, adj, cid, h1id, node_list, gtmat, obj_mask = dataset[i]
# Object types for the same sample as a tensor; the clustering step indexes it by node id.
obj_types = torch.tensor(dataset.obj_types[i])

In [5]:
# Shapes of every tensor returned for the selected sample, in the order they were unpacked.
tuple(t.shape for t in (feat, adj, cid, h1id, node_list, gtmat, obj_mask))

(torch.Size([38, 49, 512]),
 torch.Size([38, 49, 49]),
 torch.Size([38, 1]),
 torch.Size([38, 8]),
 torch.Size([38, 49]),
 torch.Size([38, 8]),
 torch.Size([38, 8]))

In [6]:
# Move the model inputs and labels to the compute device; node_list is left where it is.
feat, adj, cid, h1id, gtmat, obj_mask = (
    t.to(device) for t in (feat, adj, cid, h1id, gtmat, obj_mask)
)

In [7]:
# Forward pass for the selected sample; gradients are not needed for inference.
with torch.no_grad():
    pred = net(feat, adj, h1id)

# Normalise the two scores of each row into probabilities.
pred = F.softmax(pred, dim=1)
# Size the neighbour axis from h1id itself (as the evaluation loop further down does)
# rather than from the k_at_hop config, so the reshape follows the sample's own
# one-hop width instead of assuming it always equals k_at_hop[0].
pred = pred.view(feat.shape[0], h1id.shape[1], 2)

In [8]:
def accuracy(pred, label, masks, thres):
    """Fraction of entries whose thresholded, masked class prediction equals the label.

    Args:
        pred: class scores with the classes on the last axis, e.g. [m, 2] or [n, k, 2].
        label: target class ids, one per scored entry (shape of `pred` without its last axis).
        masks: multiplier applied to the predicted class ids; a 0 forces class 0.
        thres: a score must be strictly greater than this to count as "on".

    Returns:
        0-dim float tensor holding the mean agreement between prediction and label.
    """
    # Cast the boolean threshold result to float before argmax: argmax over a bool
    # tensor is not reliably supported, and the notebook's own cell 9 does the same.
    # Reducing over the last axis is identical to the former dim=1 for 2-D scores and
    # also handles the [n, k, 2] layout used in the rest of the notebook.
    pred = (torch.argmax((pred > thres).float(), dim=-1) * masks).long()
    acc = torch.mean((pred == label).float())
    return acc

In [9]:
# Hard class per (node, neighbour): class 1 only where its probability exceeds 0.9,
# otherwise argmax falls back to index 0 (or picks 0 when class 0 itself exceeds 0.9).
pred_c = torch.argmax((pred > 0.9).float(), dim=-1)

In [10]:
# Sanity check: one class id per (node, neighbour) pair.
pred_c.size()

torch.Size([38, 8])

In [11]:
# Link-level agreement with the ground truth, first with the object mask applied to
# the predicted classes and then without it.
tuple(
    torch.eq(gtmat, candidate).float().mean().item()
    for candidate in (pred_c * obj_mask, pred_c)
)

(0.9769737124443054, 0.9769737124443054)

## Final Prediction

In [12]:
def _lookup_notebook_global(name):
    """Return the module-level variable `name`, raising NameError when it is not set."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name '{name}' is not defined; pass it to extract_clusters_from_matrix explicitly"
        ) from None


def extract_clusters_from_matrix(pred, obj_mask, obj_types, thres=0.5, prediction=True,
                                 node_list=None, h1id=None):
    '''
    Greedily turn per-neighbour link scores into node pairs and small clusters.

    Args:
    - pred: with prediction=True, scores indexed as pred[:, :, 1] (one row of
      neighbour scores per node); with prediction=False, an [n_nodes, k1] link
      matrix that is used as-is after a float cast.
    - obj_mask: multiplied element-wise with pred[:, :, 1]; only used when
      prediction=True.
    - obj_types (tensor): one type id per node. Once a node is linked to a partner,
      its remaining candidates of that partner's type are discarded.
    - thres (float): scores must be strictly greater than this to be kept
      (prediction=True only).
    - prediction (bool): selects between the two interpretations of `pred` above.
    - node_list, h1id: per-node index tensors; node_list[r][h1id[r]] gives the node
      ids that row r of the scores refers to. When omitted they are read from the
      notebook-level variables of the same name, which is how earlier callers use
      this function.

    Returns:
    - pairs: list of 2-element sets {node, node}, in the order the links were accepted.
    - clusters: list of sets of node ids (at most 3 nodes each), sorted by each
      cluster's largest node id.
    '''
    if node_list is None:
        node_list = _lookup_notebook_global('node_list')
    if h1id is None:
        h1id = _lookup_notebook_global('h1id')

    if prediction:
        pred_confs = (pred[:, :, 1] * obj_mask).cpu() * (pred[:, :, 1] > thres).cpu()
    else:
        pred_confs = pred.float()

    # Scatter each node's neighbour scores into a dense [n_nodes, n_nodes] matrix.
    n_nodes = pred_confs.shape[0]
    a = torch.zeros(n_nodes, n_nodes)
    for r in range(n_nodes):
        n_real_ids = node_list[r][h1id[r].cpu()].long()
        a[r, n_real_ids] = pred_confs[r]

    # Greedy matching: repeatedly take the strongest remaining link.
    pairs = []
    clus_mat = torch.zeros(n_nodes, n_nodes)
    while a.sum().item() != 0:
        # torch.where lists matches in row-major order, so ties resolve to the first one.
        rows, cols = torch.where(a == a.max())
        r, c = rows[0], cols[0]
        # Zero out the chosen link in both directions.
        a[r, c] = 0
        a[c, r] = 0
        if r == c:
            # A self-link would yield a one-element set and break the (n1, n2)
            # unpacking below, so it is dropped here.
            continue
        # r may not take another partner of c's type, and vice versa.
        r_type_ids = torch.where(obj_types == obj_types[r])[0]
        c_type_ids = torch.where(obj_types == obj_types[c])[0]
        a[r, c_type_ids] = 0
        a[c, r_type_ids] = 0
        # Accept the link only while both endpoints have fewer than 3 accepted links.
        if clus_mat[r, :].sum().item() < 3 and clus_mat[c, :].sum().item() < 3:
            clus_mat[r, c] = 1
            clus_mat[c, r] = 1
            pairs.append(set([r.item(), c.item()]))

    # Grow clusters from the accepted pairs, never letting one exceed 3 nodes.
    clusters = []      # cluster id -> list of node ids (None once merged into another)
    nodes2clus = {}    # node id -> cluster id
    for (n1, n2) in pairs:
        c1 = nodes2clus.get(n1)
        c2 = nodes2clus.get(n2)
        if c1 is None and c2 is None:
            nodes2clus[n1] = nodes2clus[n2] = len(clusters)
            clusters.append([n1, n2])
        elif c2 is None:
            if len(clusters[c1]) < 3:
                nodes2clus[n2] = c1
                clusters[c1].append(n2)
        elif c1 is None:
            if len(clusters[c2]) < 3:
                nodes2clus[n1] = c2
                clusters[c2].append(n1)
        elif c1 != c2 and len(clusters[c1]) + len(clusters[c2]) <= 3:
            # Merge c2 into c1. With 2-node seeds and a cap of 3 this cannot currently
            # trigger, but the slot is tombstoned rather than popped so the ids stored
            # in nodes2clus stay valid if the cap is ever raised.
            for ni in clusters[c2]:
                nodes2clus[ni] = c1
            clusters[c1] += clusters[c2]
            clusters[c2] = None

    clusters = [set(members) for members in clusters if members is not None]
    clusters = sorted(clusters, key=lambda x: max(x))

    return pairs, clusters

In [13]:
# Cluster the selected sample twice: once from the network's link probabilities and
# once from the ground-truth link matrix (prediction=False uses the matrix as-is,
# without masking or thresholding). Only the cluster lists are inspected below.
c_, pred_clusters = extract_clusters_from_matrix(pred, obj_mask, obj_types,)
_, gt_clusters = extract_clusters_from_matrix(gtmat.cpu(), obj_mask, obj_types, prediction=False)

In [14]:
# Clusters derived from the network's predictions, ordered by each cluster's largest node id.
pred_clusters

[{0, 1},
 {2, 3, 4},
 {5, 6},
 {7, 8},
 {9, 10},
 {11, 12},
 {13, 14, 15},
 {16, 17, 18},
 {19, 20, 21},
 {22, 23, 24},
 {25, 26},
 {27, 28},
 {29, 30},
 {31, 32},
 {33, 34, 35},
 {36, 37}]

In [15]:
# Clusters derived from the ground-truth link matrix, for comparison with the prediction above.
gt_clusters

[{0, 1},
 {2, 3, 4},
 {5, 6},
 {7, 8},
 {9, 10},
 {11, 12},
 {13, 14, 15},
 {16, 17, 18},
 {19, 20, 21},
 {22, 23, 24},
 {25, 26},
 {27, 28},
 {29, 30},
 {31, 32},
 {33, 34, 35},
 {36, 37}]

In [16]:
# Evaluate cluster recovery over the whole test split.
n_samples = len(dataset)
overall_acc = 0
exact_match = 0
for i in tqdm(range(n_samples)):
    # extract_clusters_from_matrix is called without node_list/h1id and picks them up
    # from notebook scope, so both names must be rebound for every sample.
    feat, adj, cid, h1id, node_list, gtmat, obj_mask = dataset[i]
    obj_types = torch.tensor(dataset.obj_types[i])
    feat, adj, cid, h1id, gtmat, obj_mask = map(lambda x: x.to(device), (feat, adj, cid, h1id, gtmat, obj_mask))
    with torch.no_grad():
        pred = net(feat, adj, h1id)

    pred = F.softmax(pred, dim=1)
    pred = pred.view(feat.shape[0], h1id.shape[1], 2)

    c_, pred_clusters = extract_clusters_from_matrix(pred, obj_mask, obj_types,)
    _, gt_clusters = extract_clusters_from_matrix(gtmat.cpu(), obj_mask, obj_types, prediction=False)

    # Count predicted clusters that appear verbatim in the ground truth. The generator
    # keeps its loop variable local, so the `pred_c` tensor from the earlier cell is
    # no longer overwritten.
    matches = sum(cluster in gt_clusters for cluster in pred_clusters)
    # An exact match needs the two cluster lists to coincide: every ground-truth
    # cluster recovered AND no additional predicted cluster.
    is_exact = matches == len(gt_clusters) == len(pred_clusters)

    if gt_clusters:
        overall_acc += matches / len(gt_clusters)
    else:
        # No ground-truth clusters: avoid dividing by zero and score the sample as
        # correct only when nothing was predicted either.
        overall_acc += float(is_exact)
    if is_exact:
        exact_match += 1

# Average over the number of samples rather than the leaked loop variable, which is
# stale (or undefined) when the dataset is empty.
if n_samples == 0:
    print('No samples to evaluate.')
else:
    print(f'Average Acc: {overall_acc/n_samples} | Exact Match Acc: {exact_match/n_samples}')

  0%|          | 0/1003 [00:00<?, ?it/s]

Average Acc: 0.9777334662678636 | Exact Match Acc: 0.8664007976071785


In [13]:
# Leftover inspection cell: displays whatever tensor is currently bound to `adj`.
# NOTE(review): its execution count repeats an earlier cell's, so the saved output may
# be stale — confirm it is still wanted, or replace it with `adj.shape`.
adj

tensor([[[0.0000, 0.1414, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.1414, 0.0000, 0.1291,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.1291, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.1414,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.1291,  ..., 0.0000, 0.0000, 0.0000],
         [0.1414, 0.1291, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.1443,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.1667,  ..., 0.0000, 0.0000, 0.0000],
         [0.1443, 0.1667, 0.0000,  ..., 0.0000, 0.0000, 0.