In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pickle
from skimage.segmentation import slic
import scipy.ndimage
import scipy.spatial
import torch
from torchvision import datasets
from torchvision import datasets
import sys
sys.path.append("../")
from chebygin import ChebyGIN
from extract_superpixels import process_image
from graphdata import comput_adjacency_matrix_images
from train_test import load_save_noise
from utils import list_to_torch, data_to_device, normalize_zero_one

# TRIANGLES

In [2]:
# Paths and device used throughout this notebook.
# NOTE(review): data_dir is an absolute, machine-specific path; point it at your
# local copy of the graph_attention_pool data before running.
data_dir = '/scratch/ssd/data/graph_attention_pool/'
checkpoints_dir = '../checkpoints'
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
with open('%s/random_graphs_triangles_test.pkl' % data_dir, 'rb') as f:
    data = pickle.load(f)

print(data.keys())
# Ground-truth graph labels; acc() compares predictions against these.
targets = torch.from_numpy(data['graph_labels']).long()
# Row sums of each adjacency matrix, i.e. node degrees (assuming 0/1 adjacency
# matrices -- TODO confirm), cast to int so they can be used as column indices.
Node_degrees = [np.sum(A, 1).astype(np.int32) for A in data['Adj_matrices']]

# One-hot encode each node's degree into a vector of length Max_degree + 1.
# Assumes no node degree exceeds data['Max_degree'] (otherwise indexing fails).
feature_dim = data['Max_degree'] + 1
node_features = []
for i in range(len(data['Adj_matrices'])):
    N = data['Adj_matrices'][i].shape[0]
    D_onehot = np.zeros((N, feature_dim))
    D_onehot[np.arange(N), Node_degrees[i]] = 1
    node_features.append(D_onehot)

dict_keys(['Adj_matrices', 'GT_attn', 'graph_labels', 'N_edges', 'Max_degree'])


In [3]:
def acc(pred, labels=None):
    """Return the percentage of predictions that equal their ground-truth labels.

    Args:
        pred: non-empty list of single-element tensors, one per graph, in
            dataset order (``torch.stack`` rejects an empty list).
        labels: tensor of ground-truth labels; only the first ``len(pred)``
            entries are compared. Defaults to the module-level ``targets``.

    Returns:
        Accuracy in percent, as a Python float.
    """
    if labels is None:
        labels = targets
    n = len(pred)
    predicted = torch.stack(pred).view(n)
    # Predictions arrive incrementally, so compare only against the prefix seen so far.
    expected = labels[:n].view(n)
    return torch.mean((predicted == expected).float()).item() * 100

def test(model, index, show_img=False):
    """Run ``model`` on the ``index``-th test graph and return its prediction.

    Args:
        model: model called as ``model(inputs)`` that returns
            ``(y, other_outputs)`` where ``other_outputs`` is a dict.
        index: position of the graph in ``data['Adj_matrices']`` and
            ``node_features``.
        show_img: unused; kept so existing callers keep working.

    Returns:
        ``(y, alpha)``. ``y`` is element ``[0][0]`` of the rounded, long-cast
        model output, moved to CPU. ``alpha`` is ``other_outputs['alpha'][0]``
        moved to CPU, or ``[]`` if the model returns no ``'alpha'``.
    """
    N_nodes = data['Adj_matrices'][index].shape[0]
    # Batch of a single graph, with every node marked valid in the mask.
    # NOTE(review): uint8 mask presumably matches what ChebyGIN expects -- confirm before switching to bool.
    mask = torch.ones(1, N_nodes, dtype=torch.uint8)
    x = torch.from_numpy(node_features[index]).unsqueeze(0).float()
    A = torch.from_numpy(data['Adj_matrices'][index].astype(np.float32)).float().unsqueeze(0)
    inputs = [x, A, mask, -1, {'N_nodes': torch.zeros(1, 1) + N_nodes}]
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        y, other_outputs = model(data_to_device(inputs, device))
    y = y.round().long().detach().cpu()[0][0]
    alpha = other_outputs['alpha'][0].detach().cpu() if 'alpha' in other_outputs else []
    return y, alpha


def get_predictions(model_path, in_features=14):
    """Load a ChebyGIN checkpoint and return predictions for the entire clean and noise test sets.

    Every graph in ``data['Adj_matrices']`` is evaluated with ``test()``.
    Running accuracy is printed every 1000 graphs.

    Args:
        model_path: path to a checkpoint dict with keys ``'args'`` (the training
            hyper-parameters) and ``'state_dict'``.
        in_features: input feature size the model was built with. It defaults to
            14, the value previously hard-coded here.
            NOTE(review): presumably equals ``feature_dim`` (Max_degree + 1) -- confirm.

    Returns:
        ``(pred, alpha)``: lists with one entry per graph, as returned by ``test()``.
    """
    # Load onto CPU first so checkpoints saved on a GPU also open on CPU-only
    # machines; the model is moved to `device` after loading the weights.
    # NOTE(review): on PyTorch >= 2.6, torch.load defaults to weights_only=True,
    # which cannot unpickle the argparse `args` stored here; pass
    # weights_only=False for trusted checkpoints on those versions.
    state = torch.load(model_path, map_location='cpu')
    args = state['args']
    model = ChebyGIN(in_features=in_features,
                     out_features=1,
                     filters=args.filters,
                     K=args.filter_scale,
                     n_hidden=args.n_hidden,
                     aggregation=args.aggregation,
                     dropout=args.dropout,
                     readout=args.readout,
                     pool=args.pool,
                     pool_arch=args.pool_arch)
    model.load_state_dict(state['state_dict'])
    model = model.eval().to(device)

    n_graphs = len(data['Adj_matrices'])
    pred, alpha = [], []
    # Pure inference: disable autograd for the whole evaluation loop.
    with torch.no_grad():
        for index in range(n_graphs):
            y, attn = test(model, index, index == 0)
            pred.append(y)
            alpha.append(attn)
            if len(pred) % 1000 == 0:
                print('{}/{}, acc on the combined test set={:.2f}%'.format(len(pred), n_graphs, acc(pred)))
    return pred, alpha

## Weakly-supervised attention model

In [4]:
# Checkpoint of the weakly-supervised attention model, evaluated on the whole test set.
checkpoint_path = '%s/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar' % checkpoints_dir
pred, alpha = get_predictions(checkpoint_path)

ChebyGINLayer torch.Size([64, 98]) tensor([0.5471, 0.6374, 0.5764, 0.5645, 0.5911, 0.6126, 0.6002, 0.6030, 0.5544,
        0.5920], grad_fn=<SliceBackward>)
ChebyGINLayer torch.Size([32, 128]) tensor([0.5650, 0.6002, 0.5687, 0.5424, 0.6147, 0.5745, 0.5907, 0.5649, 0.5473,
        0.5304], grad_fn=<SliceBackward>)
ChebyGINLayer torch.Size([32, 64]) tensor([0.6214, 0.6180, 0.5934, 0.5447, 0.5653, 0.5324, 0.5887, 0.5464, 0.5800,
        0.5019], grad_fn=<SliceBackward>)
ChebyGINLayer torch.Size([1, 64]) tensor([0.5697], grad_fn=<SliceBackward>)
ChebyGINLayer torch.Size([64, 448]) tensor([0.5793, 0.5927, 0.5793, 0.5696, 0.5691, 0.5792, 0.5905, 0.5616, 0.5805,
        0.5668], grad_fn=<SliceBackward>)
ChebyGINLayer torch.Size([32, 128]) tensor([0.6048, 0.5640, 0.5743, 0.5179, 0.5856, 0.5736, 0.5848, 0.6068, 0.5857,
        0.5543], grad_fn=<SliceBackward>)
ChebyGINLayer torch.Size([32, 64]) tensor([0.5938, 0.5819, 0.5706, 0.5627, 0.6056, 0.5677, 0.5297, 0.5934, 0.5408,
        0.5715], grad