# installation

In [None]:
# Install PyTorch Geometric, its sparse/scatter companions, DeepSNAP and PyDrive.
# The whole cell is skipped when the IS_GRADESCOPE_ENV environment variable is set.
import os
if 'IS_GRADESCOPE_ENV' not in os.environ:
  # The torch-scatter / torch-sparse wheels come from the index named for
  # torch 1.10.0 + CUDA 11.1.
  # NOTE(review): presumably this must match the runtime's torch/CUDA build — confirm.
  !pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
  !pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
  !pip install torch-geometric
  # DeepSNAP is installed straight from its GitHub repository (quiet mode).
  !pip install -q git+https://github.com/snap-stanford/deepsnap.git
  # PyDrive is used further down to download the dataset from Google Drive.
  !pip install -U -q PyDrive

In [None]:
# Print the CUDA toolkit version reported by nvcc next to the CUDA version
# torch reports, so the two can be compared by eye.
# Skipped when IS_GRADESCOPE_ENV is set.
if 'IS_GRADESCOPE_ENV' not in os.environ:
  !nvcc --version
  !python -c "import torch; print(torch.version.cuda)"

In [None]:
  # Sanity check: print the installed torch and torch_geometric versions.
  # NOTE(review): this cell body is indented although it is top-level code;
  # presumably the notebook front end tolerates that, but a plain .py script
  # would not — confirm before exporting.
  import torch
  print(torch.__version__)
  import torch_geometric
  print(torch_geometric.__version__)

In [None]:
import copy
import torch
import deepsnap
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn

from sklearn.metrics import f1_score
from deepsnap.hetero_gnn import forward_op
from deepsnap.hetero_graph import HeteroGraph
from torch_sparse import SparseTensor, matmul

## Dataset
You need to log in to your Google account and enter the verification code below.

In [None]:
# Build an authenticated PyDrive client (`drive`) for the dataset download below.
# Skipped when IS_GRADESCOPE_ENV is set; the imports are kept inside the guard
# so that `google.colab` is only imported when the cell actually runs.
if 'IS_GRADESCOPE_ENV' not in os.environ:
  from pydrive.auth import GoogleAuth
  from pydrive.drive import GoogleDrive
  from google.colab import auth
  from oauth2client.client import GoogleCredentials

  # Authenticate and create the PyDrive client
  auth.authenticate_user()
  gauth = GoogleAuth()
  # Hand the application-default credentials to PyDrive.
  gauth.credentials = GoogleCredentials.get_application_default()
  drive = GoogleDrive(gauth)

In [None]:
# Download the dataset file from Google Drive and store it locally as 'acm.pkl'.
# Requires the `drive` client created by the authentication cell.
# Skipped when IS_GRADESCOPE_ENV is set.
if 'IS_GRADESCOPE_ENV' not in os.environ:
  # Named `file_id` rather than `id`: a module-level `id` would shadow the
  # builtin `id()` for every later cell of the notebook.
  file_id = '1ivlxd6lJMcZ9taS44TMGG72x2V1GeVvk'
  downloaded = drive.CreateFile({'id': file_id})
  downloaded.GetContentFile('acm.pkl')

# HeteroGNN model

## Implementing `HeteroGNNConv`

Now let's start working on our own implementation of the heterogeneous message passing layer (`HeteroGNNConv`)! Just as in Colabs 3 and 4, we will implement the layer using PyTorch Geometric. 

At a high level, the `HeteroGNNConv` layer is equivalent to the homogeneous GNN layers we implemented in Colab 3, but now applied to an individual heterogeneous message type. Moreover, our heterogeneous GNN layer draws directly from the **GraphSAGE** message passing model ([Hamilton et al. (2017)](https://arxiv.org/abs/1706.02216)).

We begin by defining the `HeteroGNNConv` layer with respect to message type $m$:

\begin{equation}
m =(s, r, d)
\end{equation}

where each message type is a tuple containing three elements: $s$ - the source node type, $r$ - the edge (relation) type, and $d$ - the destination node type. 

The message passing update rule that we implement is very similar to that of GraphSAGE, except we now need to include the node types and the edge relation type. The update rule for message type $m$ is described below:

\begin{equation}
h_v^{(l)[m]} = W^{(l)[m]} \cdot \text{CONCAT} \Big( W_d^{(l)[m]} \cdot h_v^{(l-1)}, W_s^{(l)[m]} \cdot AGG(\{h_u^{(l-1)}, \forall u \in N_{m}(v) \})\Big)
\end{equation}

where we compute $h_v^{(l)[m]}$, the node embedding representation for node $v$ after `HeteroGNNConv` layer $l$ with respect to message type $m$. Further unpacking the formula, we have:
- $W_s^{(l)[m]}$ - linear transformation matrix for the messages of neighboring source nodes of type $s$ along message type $m$.
- $W_d^{(l)[m]}$ - linear transformation matrix for the message from the node $v$ itself of type $d$.
- $W^{(l)[m]}$ - linear transformation matrix for the concatenated messages from the neighboring nodes and the central node.
- $h_u^{(l-1)}$ - the hidden embedding representation for node $u$ after the $(l-1)$th `HeteroGNNWrapperConv` layer. Note that this embedding is not associated with a particular message type (see layer diagrams above).
- $N_{m}(v)$ - the set of neighbor source nodes $s$ for the node v that we are embedding along message type $m = (s, r, d)$. 

**NOTE**: We emphasize that each weight matrix is associated with a specific message type $[m]$ and additionally, the weight matrices applied to node messages are differentiated by node type (i.e. $W_s$ and $W_d$).

Lastly, for simplicity, we use mean aggregations for $AGG$ where:

\begin{equation}
AGG(\{h_u^{(l-1)}, \forall u \in N_{m}(v) \}) = \frac{1}{|N_{m}(v)|} \sum_{u\in N_{m}(v)} h_u^{(l-1)}
\end{equation}

In [None]:
class HeteroGNNConv(pyg_nn.MessagePassing):
    """GraphSAGE-style convolution for one message type ``m = (s, r, d)``.

    Source-node features are mean-aggregated over each destination node's
    neighbourhood, projected by ``lin_src``, concatenated with the
    ``lin_dst`` projection of the destination node's own features, and
    finally mapped to ``out_channels`` by ``lin_update``.

    Args:
        in_channels_src: feature size of the source node type.
        in_channels_dst: feature size of the destination node type.
        out_channels: output size of all three linear layers.
    """

    def __init__(self, in_channels_src, in_channels_dst, out_channels):
        super().__init__(aggr="mean")

        self.in_channels_src = in_channels_src
        self.in_channels_dst = in_channels_dst
        self.out_channels = out_channels

        # The layers are created in the order src -> dst -> update so that the
        # random initialisation consumes the RNG in a fixed, reproducible order.
        self.lin_src = nn.Linear(in_channels_src, out_channels)    # W_s^{(l)[m]}
        self.lin_dst = nn.Linear(in_channels_dst, out_channels)    # W_d^{(l)[m]}
        # Input is CONCAT(dst, src), i.e. 2 * out_channels wide.
        self.lin_update = nn.Linear(2 * out_channels, out_channels)  # W^{(l)[m]}

    def forward(self, node_feature_src, node_feature_dst, edge_index, size=None):
        """Run message passing for this message type.

        ``edge_index`` is handed to ``propagate`` unchanged; the notebook passes
        a torch_sparse ``SparseTensor`` adjacency here. The two feature tensors
        are forwarded by keyword so that ``message_and_aggregate`` and
        ``update`` receive them by name.
        """
        return self.propagate(
            edge_index,
            size=size,
            node_feature_src=node_feature_src,
            node_feature_dst=node_feature_dst,
        )

    def message_and_aggregate(self, edge_index, node_feature_src):
        """Fused message + aggregation step.

        A single sparse-dense product of the adjacency with the source
        features, reduced with ``self.aggr`` (set to "mean" in ``__init__``),
        replaces the separate message and aggregate steps and avoids
        materialising per-edge features.
        """
        return matmul(edge_index, node_feature_src, reduce=self.aggr)

    def update(self, aggr_out, node_feature_dst):
        """Combine the aggregated neighbour message with the node's own features.

        Returns ``h_v^{(l)[m]}`` = lin_update(CONCAT(lin_dst(h_v), lin_src(AGG))).
        """
        neighbor_part = self.lin_src(aggr_out)
        self_part = self.lin_dst(node_feature_dst)
        # Destination part first, then the neighbour part, along the feature axis.
        return self.lin_update(torch.cat((self_part, neighbor_part), dim=-1))

## Heterogeneous GNN Wrapper Layer

After implementing the `HeteroGNNConv` layer for each message type, we need to manage and aggregate the node embedding results (with respect to each message types). Here we will implement two types of message type level aggregation.

The first one is simply mean aggregation over message types:

\begin{equation}
h_v^{(l)} = \frac{1}{M}\sum_{m=1}^{M}h_v^{(l)[m]}
\end{equation}

where node $v$ has node type $d$ and we sum over the $M$ message types that have destination node type $d$. From our original example, for a node v of type $b$ we aggregate v's `HeteroGNNConv` embeddings for message types $m_2$ and $m_3$ (i.e. $h_v^{(l)[m_2]}$ and $h_v^{(l)[m_3]}$).

The second method we implement is the semantic level attention introduced in **HAN** ([Wang et al. (2019)](https://arxiv.org/abs/1903.07293)). Instead of directly averaging on the message type aggregation results, we use attention to learn which message type result is more important, then aggregate across all the message types. Below are the equations for semantic level attention:

\begin{equation}
e_{m} = \frac{1}{|V_{d}|} \sum_{v \in V_{d}} q_{attn}^T \cdot tanh \Big( W_{attn}^{(l)} \cdot h_v^{(l)[m]} + b \Big)
\end{equation}

where $m$ is the message type and $d$ refers to the destination node type for that message ($m = (s, r, d)$). Additionally, $V_{d}$ refers to the set of nodes $v$ with type $d$. Lastly, the unnormalized attention weight $e_m$ is a scalar computed for each message type $m$.

Next, we can compute the normalized attention weights and update $h_v^{(l)}$:

\begin{equation}
\alpha_{m} = \frac{\exp(e_{m})}{\sum_{m=1}^M \exp(e_{m})}
\end{equation}

\begin{equation}
h_v^{(l)} = \sum_{m=1}^{M} \alpha_{m} \cdot h_v^{(l)[m]}
\end{equation}

where we emphasize that $M$ here is the number of message types associated with the destination node type $d$.

**Note**: The implementation of the attention aggregation is tricky and nuanced. We strongly recommend working carefully through the math equations to understand exactly what each notation refers to and how all the pieces fit together. If you can, try to connect the math to our original example, focusing on node type $b$, which depends on two different message types!

**_We've implemented most of this for you but you'll need to initialize self.attn_proj in the initializer_**

In [None]:
class HeteroGNNWrapperConv(deepsnap.hetero_gnn.HeteroConv):
    """Apply one conv per message type and merge the results per destination type.

    Args:
        convs: mapping from message type ``(src, relation, dst)`` to the conv
            layer handling that message type.
        args: hyper-parameter dict; ``args['hidden_size']`` and
            ``args['attn_size']`` are read only when ``aggr == "attn"``.
        aggr: how embeddings of several message types that share a destination
            node type are merged: ``"mean"`` or ``"attn"`` (semantic-level
            attention).

    Attributes:
        mapping: position of an embedding in the per-destination list ->
            message type, refreshed on every forward pass.
        alpha: numpy array with the most recent attention weights, or ``None``
            if attention aggregation has not run yet.
    """

    def __init__(self, convs, args, aggr="mean"):
        # The parent is given aggr=None; merging is done by self.aggregate below.
        super(HeteroGNNWrapperConv, self).__init__(convs, None)
        self.aggr = aggr

        # Map the index and message type
        self.mapping = {}

        # A numpy array that stores the final attention probability
        self.alpha = None

        self.attn_proj = None

        if self.aggr == "attn":
            # One term of e_m: q_attn^T * tanh(W_attn * h_v^{(l)[m]} + b_attn)
            self.attn_proj = nn.Sequential(
                nn.Linear(args['hidden_size'], args['attn_size']),
                nn.Tanh(),
                nn.Linear(args['attn_size'], 1, bias=False),
            )

    def reset_parameters(self):
        """Re-initialise the wrapped convs and, if present, the attention projection."""
        # The original referenced an undefined class name here (NameError).
        # NOTE(review): assumes the parent class implements reset_parameters() — confirm.
        super().reset_parameters()
        if self.aggr == "attn":
            for layer in self.attn_proj.children():
                # nn.Tanh has no parameters and no reset_parameters() method.
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()

    def forward(self, node_features, edge_indices):
        """Compute new embeddings for every destination node type.

        Args:
            node_features: dict node type -> feature tensor.
            edge_indices: dict message type -> edge structure passed to that
                message type's conv.

        Returns:
            dict destination node type -> embedding tensor.
        """
        # 1) One conv per message type.
        message_type_emb = {}
        for message_key, edge_index in edge_indices.items():
            src_type, _, dst_type = message_key
            message_type_emb[message_key] = self.convs[message_key](
                node_features[src_type],
                node_features[dst_type],
                edge_index,
            )

        # 2) Group the results by destination node type.
        node_emb = {dst: [] for _, _, dst in message_type_emb}
        # mapping: position inside the destination's list -> message type.
        # With several destination types, equal positions overwrite each other.
        mapping = {}
        for (src, edge_type, dst), emb in message_type_emb.items():
            mapping[len(node_emb[dst])] = (src, edge_type, dst)
            node_emb[dst].append(emb)
        self.mapping = mapping

        # 3) Merge; a single message type needs no aggregation.
        for node_type, embs in node_emb.items():
            if len(embs) == 1:
                node_emb[node_type] = embs[0]
            else:
                node_emb[node_type] = self.aggregate(embs)
        return node_emb

    def aggregate(self, xs):
        """Merge the per-message-type embeddings of one destination node type.

        Args:
            xs: list of M tensors, one per message type, all of the same shape
                (N nodes on the first axis).

        Returns:
            A tensor of the same shape as each element of ``xs``.

        Raises:
            ValueError: if ``self.aggr`` is neither "mean" nor "attn".
        """
        if self.aggr == "mean":
            # Stack on a new trailing axis and average it away.
            return torch.stack(xs, dim=-1).mean(dim=-1)

        if self.aggr == "attn":
            N = xs[0].shape[0]  # number of nodes of this node type
            M = len(xs)         # number of message types for this node type

            x = torch.stack(xs, dim=0)            # (M, N, D)
            z = self.attn_proj(x).view(M, N)      # (M, N): one score per node
            z = z.mean(1)                         # (M,): e_m, averaged over nodes
            alpha = torch.softmax(z, dim=0)       # (M,): normalised over message types

            # Keep the attention weights around for inspection (numpy, on CPU).
            self.alpha = alpha.detach().cpu().numpy()

            return (x * alpha.view(M, 1, 1)).sum(dim=0)

        # Previously this fell through and returned None, which only failed
        # later in an unrelated place.
        raise ValueError(f"Unsupported aggregation type: {self.aggr!r}")

## Initialize Heterogeneous GNN Layers

Now let's put it all together and initialize the Heterogeneous GNN Layers. Different from the homogeneous graph case, heterogeneous graphs can be a little bit complex.

In general, we need to create a dictionary of `HeteroGNNConv` layers where the keys are message types.

* To get all message types, `deepsnap.hetero_graph.HeteroGraph.message_types` is useful.
* If we are initializing the first conv layers, we need to get the feature dimension of each node type. Using `deepsnap.hetero_graph.HeteroGraph.num_node_features(node_type)` will return the node feature dimension of `node_type`. In this function, we will set each `HeteroGNNConv` `out_channels` to be `hidden_size`.
* If we are not initializing the first conv layers, all node types will have the same embedding dimension `hidden_size` and we still set `HeteroGNNConv` `out_channels` to be `hidden_size` for simplicity.



In [None]:
def generate_convs(hetero_graph, conv, hidden_size, first_layer=False):
    """Build one conv layer per message type of ``hetero_graph``.

    Args:
        hetero_graph: object exposing ``message_types`` and
            ``num_node_features(node_type)``.
        conv: callable ``conv(in_channels_src, in_channels_dst, out_channels)``
            that creates a layer (e.g. the ``HeteroGNNConv`` class).
        hidden_size: output size of every layer, and also the input size of
            every layer unless ``first_layer`` is true.
        first_layer: when true, the input sizes are the raw feature sizes of
            the message type's source and destination node types.

    Returns:
        dict mapping each message type to its newly created layer.
    """

    def _input_dims(message_type):
        # Later layers consume hidden embeddings on both sides.
        if not first_layer:
            return hidden_size, hidden_size
        # First layer: look up the raw feature size of (source, destination).
        return (
            hetero_graph.num_node_features(message_type[0]),
            hetero_graph.num_node_features(message_type[-1]),
        )

    layers = {}
    for message_type in hetero_graph.message_types:
        src_dim, dst_dim = _input_dims(message_type)
        layers[message_type] = conv(src_dim, dst_dim, hidden_size)
    return layers

In [None]:
class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GNN with a per-node-type classification head.

    Pipeline per forward pass: conv -> batch norm -> LeakyReLU, twice,
    followed by a linear layer per node type that maps ``hidden_size`` to the
    number of labels of that node type.

    Args:
        hetero_graph: graph object providing ``message_types``, ``node_types``,
            ``num_node_features`` and ``num_node_labels``.
        args: hyper-parameter dict; ``'hidden_size'`` and ``'eps'`` are read
            here, and the dict is also passed on to ``HeteroGNNWrapperConv``.
        aggr: message-type aggregation passed to ``HeteroGNNWrapperConv``.
    """

    def __init__(self, hetero_graph, args, aggr="mean"):
        super().__init__()

        self.aggr = aggr
        self.hidden_size = args['hidden_size']

        # Per-node-type containers are registered before the conv wrappers so
        # that the sub-module order of the model stays as it always was.
        self.bns1 = nn.ModuleDict()
        self.bns2 = nn.ModuleDict()
        self.relus1 = nn.ModuleDict()
        self.relus2 = nn.ModuleDict()
        self.post_mps = nn.ModuleDict()

        # Layer 1 consumes raw node features, layer 2 consumes hidden embeddings.
        self.convs1 = HeteroGNNWrapperConv(
            generate_convs(hetero_graph, HeteroGNNConv, self.hidden_size, first_layer=True),
            args,
            aggr=self.aggr,
        )
        self.convs2 = HeteroGNNWrapperConv(
            generate_convs(hetero_graph, HeteroGNNConv, self.hidden_size, first_layer=False),
            args,
            aggr=self.aggr,
        )

        for node_type in hetero_graph.node_types:
            self.bns1[node_type] = nn.BatchNorm1d(self.hidden_size, eps=args['eps'])
            self.bns2[node_type] = nn.BatchNorm1d(self.hidden_size, eps=args['eps'])
            self.relus1[node_type] = nn.LeakyReLU()
            self.relus2[node_type] = nn.LeakyReLU()
            self.post_mps[node_type] = nn.Linear(
                self.hidden_size, hetero_graph.num_node_labels(node_type)
            )

    def forward(self, node_feature, edge_index):
        """Return per-node-type outputs of the classification head.

        Args:
            node_feature: dict node type -> feature tensor.
            edge_index: dict message type -> edge structure for that type.

        ``forward_op`` is used to apply a ModuleDict (keyed by node type) to
        the matching entries of the embedding dict.
        """
        hidden = self.convs1(node_feature, edge_index)
        for stage in (self.bns1, self.relus1):
            hidden = forward_op(hidden, stage)

        hidden = self.convs2(hidden, edge_index)
        for stage in (self.bns2, self.relus2, self.post_mps):
            hidden = forward_op(hidden, stage)

        return hidden

    def loss(self, preds, y, indices):
        """Sum of cross-entropy losses over the node types listed in ``indices``.

        Args:
            preds: dict node type -> logits.
            y: dict node type -> labels.
            indices: dict node type -> index of the supervised nodes; only
                these rows of ``preds`` and ``y`` contribute.

        Returns:
            The summed loss; the plain integer 0 when ``indices`` is empty.
        """
        return sum(
            F.cross_entropy(preds[node_type][labeled], y[node_type][labeled])
            for node_type, labeled in indices.items()
        )

## Training and Testing

In [None]:
import pandas as pd

def train(model, optimizer, hetero_graph, train_idx):
    """Run one full-graph optimisation step and return the loss as a float.

    Args:
        model: module that is callable as ``model(node_feature, edge_index)``
            and provides ``loss(preds, labels, indices)``.
        optimizer: optimizer over the model's parameters.
        hetero_graph: object with ``node_feature``, ``edge_index`` and
            ``node_label`` attributes.
        train_idx: indices of the training nodes, forwarded to ``model.loss``.
    """
    model.train()
    optimizer.zero_grad()

    predictions = model(hetero_graph.node_feature, hetero_graph.edge_index)
    # node_label is handed to model.loss together with the training indices.
    objective = model.loss(predictions, hetero_graph.node_label, train_idx)

    objective.backward()
    optimizer.step()

    return objective.item()

def test(model, graph, indices, best_model=None, best_val=0, save_preds=False, agg_type=None):
    """Evaluate ``model`` on several index splits and track the best model.

    Args:
        model: module callable as ``model(node_feature, edge_index)`` that
            returns a dict node type -> logits.
        graph: object with ``node_feature``, ``edge_index`` and ``node_label``.
        indices: sequence of splits, each a dict node type -> node indices.
            The callers in this notebook pass ``[train, val, test]``.
        best_model: best model found so far (returned unchanged unless beaten).
        best_val: best micro-F1 seen so far on split 1 (validation).
        save_preds: when true, the predictions and labels of split 2 (test)
            are written to ``'ACM-Node-' + agg_type + 'Agg.csv'``.
        agg_type: label used in the log line and the CSV file name; must be a
            string when ``save_preds`` is true.

    Returns:
        ``(accs, best_model, best_val)`` where ``accs`` holds one
        ``(micro_f1, macro_f1)`` pair per split, averaged over node types.
    """
    model.eval()
    accs = []

    # The forward pass does not depend on the split, so run it once, and
    # without building an autograd graph since nothing is back-propagated here.
    with torch.no_grad():
        preds = model(graph.node_feature, graph.edge_index)

    for i, index in enumerate(indices):
        num_node_types = 0
        micro = 0.0
        macro = 0.0
        pred_np = label_np = None
        for node_type in preds:
            idx = index[node_type]
            # Predicted class = position of the largest logit in each row.
            pred = preds[node_type][idx].max(1)[1]
            label_np = graph.node_label[node_type][idx].cpu().numpy()
            pred_np = pred.cpu().numpy()
            # Accumulate (the original overwrote these), so the division below
            # yields a true average when there is more than one node type.
            micro += f1_score(label_np, pred_np, average='micro')
            macro += f1_score(label_np, pred_np, average='macro')
            num_node_types += 1

        # Averaging f1 score might not make sense, but in our example we only
        # have one node type
        denom = max(num_node_types, 1)
        accs.append((micro / denom, macro / denom))

        # Only save the test set predictions and labels!
        # pred_np / label_np hold the values of the last node type visited.
        if save_preds and i == 2:
            print ("Saving Heterogeneous Node Prediction Model Predictions with Agg:", agg_type)
            print()

            df = pd.DataFrame(data={'pred': pred_np, 'label': label_np})
            # Save locally as csv
            df.to_csv('ACM-Node-' + agg_type + 'Agg.csv', sep=',', index=False)

    # Model selection uses the micro-F1 of split 1 (validation), if present.
    if len(accs) > 1 and accs[1][0] > best_val:
        best_val = accs[1][0]
        best_model = copy.deepcopy(model)
    return accs, best_model, best_val

## Dataset and Preprocessing

We will load the data and create a tensor backend (without a NetworkX graph) `deepsnap.hetero_graph.HeteroGraph` object.

We use the `ACM(3025)` dataset for the node property prediction task. It was proposed in **HAN** ([Wang et al. (2019)](https://arxiv.org/abs/1903.07293)), and our dataset is extracted from [DGL](https://www.dgl.ai/)'s [ACM.mat](https://data.dgl.ai/dataset/ACM.mat).

The original ACM dataset has three node types and two edge (relation) types. For simplicity, we simplify the heterogeneous graph to one node type and two edge types (shown below). This means that in our heterogeneous graph, we have one node type (paper) and two message types (paper, author, paper) and (paper, subject, paper).

<br/>
<center>
<img src="http://web.stanford.edu/class/cs224w/images/colab4/cs224w-acm.png"/>
</center>

In [None]:
# Hyper-parameters shared by the model, the optimizer and the training loop.
args = dict(
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    hidden_size=64,      # embedding width of both conv layers
    epochs=100,          # number of training iterations
    weight_decay=1e-5,   # passed to the Adam optimizer
    lr=0.003,            # passed to the Adam optimizer
    attn_size=32,        # hidden width of the attention projection
    eps=1.0,             # eps of the BatchNorm1d layers
)

In [None]:
# Load 'acm.pkl', build the DeepSNAP HeteroGraph, move everything to the target
# device and replace each edge index by a transposed SparseTensor adjacency.
# Skipped when IS_GRADESCOPE_ENV is set.
if 'IS_GRADESCOPE_ENV' not in os.environ:
  print("Device: {}".format(args['device']))

  # Load the data
  # NOTE(review): torch.load unpickles the file — only load files from a trusted source.
  data = torch.load("acm.pkl")

#   Keys read from `data` below: 'label', 'feature', 'pap', 'psp', 'train_idx', 'val_idx', 'test_idx'

  # Message types
  message_type_1 = ("paper", "author", "paper")
  message_type_2 = ("paper", "subject", "paper")

  # Dictionary of edge indices
  edge_index = {}
  edge_index[message_type_1] = data['pap']  # expected 2 x |E| (rows 0 and 1 are read below) — TODO confirm
  edge_index[message_type_2] = data['psp']  # presumably the same layout as 'pap' — TODO confirm

  # Dictionary of node features
  node_feature = {}
  node_feature["paper"] = data['feature']   # presumably |V| x D_node — TODO confirm

  # Dictionary of node labels
  node_label = {}
  node_label["paper"] = data['label']  # presumably one label per node (|V|) — TODO confirm

  # Load the train, validation and test indices
  train_idx = {"paper": data['train_idx'].to(args['device'])}
  val_idx = {"paper": data['val_idx'].to(args['device'])}
  test_idx = {"paper": data['test_idx'].to(args['device'])}

  # Construct a deepsnap tensor backend HeteroGraph
  hetero_graph = HeteroGraph(
      node_feature=node_feature,
      node_label=node_label,
      edge_index=edge_index,
      directed=True
  )

  print(f"ACM heterogeneous graph: {hetero_graph.num_nodes()} nodes, {hetero_graph.num_edges()} edges")

  # Node feature and node label to device
  for key in hetero_graph.node_feature:
      hetero_graph.node_feature[key] = hetero_graph.node_feature[key].to(args['device'])
  for key in hetero_graph.node_label:
      hetero_graph.node_label[key] = hetero_graph.node_label[key].to(args['device'])

  # Edge_index to sparse tensor (sparse adjacency matrix) and to device
  # NOTE: from here on the name `edge_index` is rebound from the dict built
  # above to the per-message-type edge tensor of the current iteration.
  for key in hetero_graph.edge_index:
      edge_index = hetero_graph.edge_index[key]
      # Both sides of the adjacency are sized by the number of 'paper' nodes;
      # this hard-codes the single node type used in this notebook.
      adj = SparseTensor(row=edge_index[0], col=edge_index[1], sparse_sizes=(hetero_graph.num_nodes('paper'), hetero_graph.num_nodes('paper')))
      # The adjacency is stored transposed.
      # NOTE(review): presumably so that matmul(adj_t, x) aggregates source
      # features into destination rows — confirm against the PyG sparse notes.
      hetero_graph.edge_index[key] = adj.t().to(args['device'])
  print(hetero_graph.edge_index[message_type_1])
  print(hetero_graph.edge_index[message_type_2])


# hetero_graph.edge_index- a dictionary:
# * key=message (src node_type, relation, dst node_type)
# * value=sparseTensor of edge_index
  

In [None]:
# Train the model, report F1 scores per epoch, then evaluate the best model
# (by validation micro-F1) and save its test predictions to CSV.
# Skipped when IS_GRADESCOPE_ENV is set.
if 'IS_GRADESCOPE_ENV' not in os.environ:
  # Choose the message-type aggregation in ONE place: "attn" (semantic-level
  # attention) or "mean". The model and the CSV label are both derived from it,
  # so they cannot get out of sync.
  aggr_mode = "attn"
  agg_label = {"attn": "Attention", "mean": "Mean"}[aggr_mode]

  best_model = None
  best_val = 0

  model = HeteroGNN(hetero_graph, args, aggr=aggr_mode).to(args['device'])
  optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

  for epoch in range(args['epochs']):
      loss = train(model, optimizer, hetero_graph, train_idx)
      accs, best_model, best_val = test(model, hetero_graph, [train_idx, val_idx, test_idx], best_model, best_val)
      
      print(
          f"Epoch {epoch + 1}: loss {round(loss, 5)}, "
          f"train micro {round(accs[0][0] * 100, 2)}%, train macro {round(accs[0][1] * 100, 2)}%, "
          f"valid micro {round(accs[1][0] * 100, 2)}%, valid macro {round(accs[1][1] * 100, 2)}%, "
          f"test micro {round(accs[2][0] * 100, 2)}%, test macro {round(accs[2][1] * 100, 2)}%"
      )

  # test() only snapshots a model when the validation micro-F1 strictly exceeds
  # best_val; if that never happened, evaluate the final model instead of None.
  if best_model is None:
      best_model = model

  best_accs, _, _ = test(best_model, hetero_graph, [train_idx, val_idx, test_idx], save_preds=True, agg_type=agg_label)
  print(
      f"Best model: "
      f"train micro {round(best_accs[0][0] * 100, 2)}%, train macro {round(best_accs[0][1] * 100, 2)}%, "
      f"valid micro {round(best_accs[1][0] * 100, 2)}%, valid macro {round(best_accs[1][1] * 100, 2)}%, "
      f"test micro {round(best_accs[2][0] * 100, 2)}%, test macro {round(best_accs[2][1] * 100, 2)}%"
  )

### Attention for each Message Type

Through message-type-level attention we can learn which message type is more important to each layer.

Here we print how much attention each layer pays to each message type.

In [None]:
# Print, for both conv layers, the attention weight given to each message type.
# Skipped when IS_GRADESCOPE_ENV is set.
if 'IS_GRADESCOPE_ENV' not in os.environ:
  numbered_layers = ((1, model.convs1), (2, model.convs2))
  # Nothing is printed unless both layers have recorded attention weights.
  if all(wrapper.alpha is not None for _, wrapper in numbered_layers):
      for layer_no, wrapper in numbered_layers:
          # mapping: index into alpha -> message type
          for idx, message_type in wrapper.mapping.items():
              print(f"Layer {layer_no} has attention {wrapper.alpha[idx]} on message type {message_type}")