In [3]:
import torch

In [13]:
import torch.nn.functional as F
from torch.autograd import grad


class Grad_CAM():
    """Grad-CAM heatmap generator hooked onto ``model.features[29]``.

    The constructor registers a forward hook and a backward hook on
    ``model.features[29]`` so that every forward pass stores that layer's
    output for the first sample of the batch, and every backward pass stores
    the gradient flowing into that output.

    Args:
        model: a module exposing an indexable ``features`` container whose
            element 29 is the layer to explain (the original author marks it
            as a ReLU layer; presumably a VGG16-style network - TODO confirm).
    """

    def __init__(self, model):
        self.model = model
        self.activations = None
        self.gradient = None

        # register hooks to capture the feature-map activations and gradients
        def forward_hook(module, input, output):
            # keep only the first sample of the batch: C x H x W
            self.activations = output[0]

        def backward_hook(module, grad_input, grad_output):
            # gradient w.r.t. the layer output, first sample only: C x H x W
            self.gradient = grad_output[0][0]

        feat_map = model.features[29]  # Relu layer
        feat_map.register_forward_hook(forward_hook)
        # NOTE(review): the legacy (non-full) backward hook is kept on purpose;
        # register_full_backward_hook can fail on modules that modify their
        # input in place (e.g. ReLU(inplace=True)) - confirm before migrating.
        feat_map.register_backward_hook(backward_hook)

    def get_grad_cam(self, img, indices=None):
        """Compute Grad-CAM heatmaps for the requested (or top-3) classes.

        Args:
            img: input batch fed to ``self.model``; only the first sample is
                explained.
            indices: optional sequence of class indices. When ``None`` the
                top-3 scoring classes of the first sample are used.

        Returns:
            A tuple ``(heatmaps, values, indices)`` of numpy arrays:
            ``heatmaps`` has shape ``(k, 224, 224)``; ``values`` holds the
            top-k scores, or the placeholder ranks ``k+1 .. 2`` when
            ``indices`` was supplied (``[4, 3, 2]`` for three indices);
            ``indices`` holds the explained class indices.
        """
        self.model.eval()
        out = self.model(img)  # 1 x num_classes
        # Detach: the heatmap itself never needs to be differentiated.
        activations = self.activations.detach()  # C x H x W

        if indices is None:
            topk = min(3, out.shape[1])
            values, indices = torch.topk(out, topk)
        else:
            indices = torch.as_tensor(indices, dtype=torch.long).reshape(1, -1)
            # One heatmap per requested class instead of a fixed 3.
            topk = indices.shape[1]
            # Placeholder descending ranks (no numpy dependency needed).
            values = torch.arange(topk + 1, 1, -1).unsqueeze(0)

        # One heatmap per class, at the spatial size of the hooked layer.
        heatmaps = torch.zeros(
            topk, *activations.shape[1:], device=activations.device)
        for i, c in enumerate(indices[0].tolist()):
            self.model.zero_grad()
            out[0, c].backward(retain_graph=True)
            # feature importance: spatial mean of the gradient per channel
            feature_importance = self.gradient.mean(dim=[1, 2])  # C
            # pixel importance: importance-weighted sum over channels
            cam = (feature_importance[:, None, None] * activations).sum(dim=0)
            cam = F.relu(cam)
            peak = cam.max()
            # An all-zero map would otherwise become NaN after normalisation.
            if peak > 0:
                cam = cam / peak
            heatmaps[i] = cam

        # Upsample to 224x224
        large_heatmaps = F.interpolate(
            heatmaps.unsqueeze(0), (224, 224),
            mode='bilinear', align_corners=False)
        return (large_heatmaps[0].detach().cpu().numpy(),
                values.detach().cpu().numpy()[0],
                indices.detach().cpu().numpy()[0])


In [20]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Batch





class GCN(torch.nn.Module):
    """Graph-level regressor/classifier built from stacked GCNConv layers.

    Node features pass through ``num_layers`` GCNConv layers (leaky ReLU
    between layers, none after the last). Each graph's node embeddings are
    then flattened, projected by ``readout_lin``, activated, and mapped to a
    single output by ``lin``.
    """

    def __init__(self):
        super(GCN, self).__init__()
        self.num_layers = 2
        self.dropout = 0.5
        self.hidden_size = 64
        self.num_classes = 1
        self.node_sz = 400

        # First layer consumes the 400-dim node features; the rest are
        # hidden -> hidden.
        self.convs = torch.nn.ModuleList()
        for depth in range(self.num_layers):
            in_channels = 400 if depth == 0 else self.hidden_size
            self.convs.append(GCNConv(in_channels, self.hidden_size))

        # Readout over the concatenation of all node embeddings of a graph.
        self.readout_lin = nn.Linear(
            self.node_sz * self.hidden_size, self.hidden_size)

        self.lin = nn.Linear(self.hidden_size, 1)

        self.relu = nn.ReLU()

    def convert_edge_positive(self, edge_index, edge_weight):
        """Keep only the edges whose weight is strictly positive."""
        keep = edge_weight > 0
        return edge_index[:, keep], edge_weight[keep]

    def forward(self, data, **kwargs):
        """Run the network on a batched graph object.

        Reads ``x``, ``edge_index``, ``edge_weight``, ``batch`` and ``y``
        from ``data`` (``y`` is read but not used) and returns one row per
        graph id found in ``batch``. Extra keyword arguments are ignored.
        """
        x = data.x
        edge_index, edge_weight = data.edge_index, data.edge_weight
        batch, labels = data.batch, data.y

        edge_index, edge_weight = self.convert_edge_positive(
            edge_index, edge_weight)

        final_layer = self.num_layers - 1
        for depth in range(self.num_layers):
            x = self.convs[depth](x, edge_index, edge_weight)
            if depth != final_layer:
                x = F.leaky_relu(x)

        # Flatten every graph's node embeddings and project them.
        per_graph = [
            self.readout_lin(x[batch == graph_idx].view(-1))
            for graph_idx in batch.unique()
        ]
        x = torch.stack(per_graph).to(x.device)

        x = F.leaky_relu(x)

        return self.lin(x)


In [None]:
class Attn_Net_Gated(nn.Module):
    """Attention network with sigmoid gating (3 fc layers).

    Args:
        L: input feature dimension
        D: hidden layer dimension
        dropout: whether to use dropout (p = 0.25)
        n_classes: number of classes
    """

    def __init__(self, L=64, D=256, dropout=True, n_classes=1):
        super(Attn_Net_Gated, self).__init__()
        tanh_branch = [nn.Linear(L, D), nn.Tanh()]
        gate_branch = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout:
            tanh_branch += [nn.Dropout(0.25)]
            gate_branch += [nn.Dropout(0.25)]

        self.attention_a = nn.Sequential(*tanh_branch)
        self.attention_b = nn.Sequential(*gate_branch)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        """Return ``(scores, x)``: unnormalised attention scores and the input."""
        # Tanh branch modulated element-wise by the sigmoid gate.
        gated = self.attention_a(x) * self.attention_b(x)
        scores = self.attention_c(gated)  # N x n_classes
        # No softmax is applied here; scores are returned raw.
        return scores, x


In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
# from torch_geometric.nn import global_mean_pool, HypergraphConv
from torch_geometric.data import Batch

from omegaconf import DictConfig


import ipdb


class Attn_Net_Gated(nn.Module):
    """Attention network with sigmoid gating (3 fc layers).

    Args:
        L: input feature dimension
        D: hidden layer dimension
        dropout: whether to use dropout (p = 0.25)
        n_classes: number of classes
    """

    def __init__(self, L=64, D=256, dropout=True, n_classes=1):
        super(Attn_Net_Gated, self).__init__()
        tanh_branch = [nn.Linear(L, D), nn.Tanh()]
        gate_branch = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout:
            tanh_branch += [nn.Dropout(0.25)]
            gate_branch += [nn.Dropout(0.25)]

        self.attention_a = nn.Sequential(*tanh_branch)
        self.attention_b = nn.Sequential(*gate_branch)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        """Return ``(scores, x)``: unnormalised attention scores and the input."""
        # Tanh branch modulated element-wise by the sigmoid gate.
        gated = self.attention_a(x) * self.attention_b(x)
        scores = self.attention_c(gated)  # N x n_classes
        # No softmax is applied here; scores are returned raw.
        return scores, x


class DwHGN(torch.nn.Module):
    """Hypergraph network: DwHGNConv layers, optional gated node attention,
    and a flatten-and-project readout producing one output per graph.

    Args:
        cfg: optional configuration object. ``forward`` reads the flags
            ``cfg.model.node_attn_interpret``, ``cfg.model.node_attn_learn``
            and ``cfg.model.node_attn_save`` from it. When ``cfg`` is
            ``None`` (the default) every flag is treated as ``False`` and the
            node-attention block is skipped. Previously ``forward`` read
            ``self.cfg`` although it was never assigned, so every call
            raised ``AttributeError``.
    """

    def __init__(self, cfg=None):
        super(DwHGN, self).__init__()

        self.cfg = cfg

        self.num_layers = 1
        self.dropout = 0.5
        self.hidden_size = 64
        self.num_classes = 2
        self.node_sz = 400

        self.num_edges = 400

        self.convs = torch.nn.ModuleList()
        for i in range(self.num_layers):
            if i == 0:
                self.convs.append(DwHGNConv(
                    self.node_sz, self.hidden_size, num_edges=self.num_edges))
            else:
                # NOTE(review): this branch passes three positional sizes and
                # is unreachable while num_layers == 1 - confirm the intended
                # DwHGNConv signature before increasing num_layers.
                self.convs.append(DwHGNConv(
                    self.node_sz, self.hidden_size, self.hidden_size, num_edges=self.num_edges))

        # if self.cfg.model.readout == 'set_transformer':
        #     self.readout_layer = SetTransformer(dim_input=self.hidden_size,
        #                                         num_outputs=1, dim_output=self.hidden_size)
        # elif self.cfg.model.readout == 'janossy':
        #     self.readout_layer = JanossyPooling(
        #         num_perm=cfg.model.num_perm, in_features=self.hidden_size, fc_out_features=self.hidden_size)
        self.readout_lin = nn.Linear(
            self.node_sz * self.hidden_size, self.hidden_size)

        self.lin = nn.Linear(self.hidden_size, 1)

        # interpretability: always constructed so checkpoints containing
        # attn_gated.* weights load regardless of the config flags
        self.attn_gated = Attn_Net_Gated(L=self.hidden_size)

    def _model_flag(self, name):
        """Return ``cfg.model.<name>`` as a bool; False when no cfg is attached."""
        cfg = getattr(self, 'cfg', None)
        if cfg is None:
            return False
        return bool(getattr(cfg.model, name))

    def forward(self, data, **kwargs):
        """Run the network on a batched hypergraph object.

        Args:
            data: object providing ``x``, ``edge_index``, ``edge_weight``,
                ``batch`` and ``y``.
            **kwargs: must contain ``epoch``, ``iteration`` and
                ``test_phase``; they are stored on the module, and ``epoch``
                is forwarded to every conv layer.

        Returns:
            Tensor with one row per graph id found in ``batch``.

        Raises:
            KeyError: if ``epoch``, ``iteration`` or ``test_phase`` is missing.
        """
        self.epoch = kwargs['epoch']
        self.iteration = kwargs['iteration']
        self.test_phase = kwargs['test_phase']
        x, hyperedge_index, hyperedge_weight, batch, labels = data.x, data.edge_index, data.edge_weight, data.batch, data.y
        for i in range(self.num_layers):
            # x = self.convs[i](x, hyperedge_index, hyperedge_weight, self.num_edges)
            x = self.convs[i](x, hyperedge_index, epoch=self.epoch)

            # NOTE(review): this condition is always true, so the activation
            # also follows the last layer (GCN in this notebook uses
            # ``num_layers - 1``). Left unchanged so existing checkpoints
            # keep producing the same outputs.
            if i < self.num_layers:
                x = F.leaky_relu(x)

        # node_attn (interpretability)
        if self._model_flag('node_attn_interpret'):
            xs = []
            saved_A = []
            for graph_idx in batch.unique():
                graph_nodes = x[batch == graph_idx]
                A, x_new = self.attn_gated(graph_nodes)
                saved_A.append(A.view(-1))
                # Broadcasting A to the same dimensions as x
                A_broadcasted = A.expand_as(x_new)
                # Performing element-wise multiplication
                if self._model_flag('node_attn_learn'):
                    x_new = A_broadcasted * x_new
                xs.append(x_new)
            x = torch.stack(xs).to(x.device)
            x = x.view(-1, self.hidden_size)

            if self._model_flag('node_attn_save'):
                saved_A = torch.stack(saved_A)
                node_att_data_save(saved_A, self.epoch,
                                   self.iteration, labels, train=True)

        # if self.cfg.model.readout in ['set_transformer', 'janossy']:
        #     x = x.view(-1, self.node_sz, self.hidden_size)
        #     x = self.readout_layer(x)
        #     x = x.squeeze()
        # else:
        # Flatten every graph's node embeddings and project them.
        xs = []
        for graph_idx in batch.unique():
            graph_nodes = x[batch == graph_idx]
            graph_nodes = graph_nodes.view(-1)
            xs.append(self.readout_lin(graph_nodes))
        x = torch.stack(xs).to(x.device)
        # if kwargs['test_phase'] and self.cfg.model.tsne:
        #     tsne_plot_data(x, labels, self.epoch, self.iteration)

        # if self.cfg.model.tsne:
        #     if kwargs['test_phase']:
        #         tsne_plot_data(x, labels, self.epoch, self.iteration)
        #     elif self.cfg.model.tsne_train:
        #         tsne_plot_data(x, labels, self.epoch,
        #                        self.iteration, train=True)

        x = F.leaky_relu(x)
        x = self.lin(x)

        return x


In [21]:
modelx = GCN()
# The model is built on CPU, so map the checkpoint tensors to CPU as well;
# without map_location a checkpoint saved from a GPU run cannot be loaded on
# a machine without CUDA.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
modelx.load_state_dict(torch.load(
    '/home/mehul/asd_graph/baselines/outputs/2023-09-12/18-46-50/best_model.pt',
    map_location='cpu'))


<All keys matched successfully>

In [22]:
# Display the module tree of the loaded GCN.
# NOTE(review): the saved output of this cell does not list the `relu`
# submodule that GCN.__init__ registers, so it appears to come from an
# earlier version of the class - re-run to refresh.
modelx


GCN(
  (convs): ModuleList(
    (0): GCNConv(400, 64)
    (1): GCNConv(64, 64)
  )
  (readout_lin): Linear(in_features=25600, out_features=64, bias=True)
  (lin): Linear(in_features=64, out_features=1, bias=True)
)

In [11]:
# Load a raw checkpoint (state dict only, no model instance) for inspection.
# Alternative checkpoint previously inspected:
# '/home/mehul/asd_graph/baselines/outputs/2023-09-12/18-59-42/best_model.pt'
import torch
saved_weights_path = '/home/mehul/asd_graph/baselines/outputs/2023-09-12/19-07-24/best_model.pt'
# NOTE(review): no map_location is given; the tensor displayed further down
# reports device='cuda:0', so this load presumably needs a CUDA machine.
saved_state_dict = torch.load(saved_weights_path)


In [12]:
# List the parameter names stored in the checkpoint.
saved_state_dict.keys()


odict_keys(['convs.0.learned_he_weights', 'convs.0.bias', 'convs.0.lin.weight', 'readout_lin.weight', 'readout_lin.bias', 'lin.weight', 'lin.bias', 'attn_gated.attention_a.0.weight', 'attn_gated.attention_a.0.bias', 'attn_gated.attention_b.0.weight', 'attn_gated.attention_b.0.bias', 'attn_gated.attention_c.weight', 'attn_gated.attention_c.bias'])

In [13]:
# Print every checkpoint entry's name together with its tensor shape.
for key, val in saved_state_dict.items():
    print(key, ":", val.shape)

convs.0.learned_he_weights : torch.Size([400])
convs.0.bias : torch.Size([64])
convs.0.lin.weight : torch.Size([64, 400])
readout_lin.weight : torch.Size([64, 25600])
readout_lin.bias : torch.Size([64])
lin.weight : torch.Size([1, 64])
lin.bias : torch.Size([1])
attn_gated.attention_a.0.weight : torch.Size([256, 64])
attn_gated.attention_a.0.bias : torch.Size([256])
attn_gated.attention_b.0.weight : torch.Size([256, 64])
attn_gated.attention_b.0.bias : torch.Size([256])
attn_gated.attention_c.weight : torch.Size([1, 256])
attn_gated.attention_c.bias : torch.Size([1])


In [16]:
# Inspect the first conv layer's linear weight (shape 64 x 400 per the
# listing printed by the previous cell).
saved_state_dict['convs.0.lin.weight']


tensor([[-0.0583, -0.0433,  0.0879,  ...,  0.0387, -0.0656,  0.0341],
        [ 0.0040,  0.0504,  0.0245,  ..., -0.0366, -0.0294,  0.1019],
        [-0.0304,  0.0243, -0.0381,  ..., -0.1051,  0.0069,  0.0738],
        ...,
        [-0.0790, -0.0708, -0.0131,  ...,  0.1116, -0.0976, -0.0517],
        [ 0.0122, -0.0529, -0.0747,  ..., -0.0019,  0.0374, -0.0132],
        [-0.0894, -0.0129,  0.0185,  ..., -0.0135,  0.0117,  0.1099]],
       device='cuda:0')