In [1]:
%load_ext autoreload
%autoreload 2

In [11]:
from typing import Tuple

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px

import torch
from torch_geometric.data import Data
import torch_geometric.nn

import statsmodels.api as sm
import statsmodels.formula.api as smf

from examples.introduction.my_GAT_implem import GNN_naive_framework

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Data Loading

In [3]:
data = pd.read_excel("data/paired_data_newSim.xlsx")

In [4]:
_temp = data["word_pair"].str.split(".")
data["word1"] = _temp.apply(func=lambda x: x[0])
data["word2"] = _temp.apply(func=lambda x: x[1][1:])

data.head()

Unnamed: 0,word_pair,rated_similarity,abs_liking_diference,word1_liking,word2_liking,word1_experience,word2_experience,depression,depressionCont,female,age,participant,senenceBERT_mpnet_similarity,senenceBERT_miniLM_similarity,sense2vec_similarity,gptLarge_similarity,word1,word2
0,Art gallery. Autobiography book.,,0,21,21,6,8,0,12,1,29,1,0.375817,0.275882,0.337977,0.6715,Art gallery,Autobiography book
1,Art gallery. Baking cookies.,,59,21,80,6,14,0,12,1,29,1,0.246449,0.14693,0.209372,0.593072,Art gallery,Baking cookies
2,Art gallery. Board games.,,57,21,78,6,94,0,12,1,29,1,0.347372,0.224889,0.22729,0.726036,Art gallery,Board games
3,Art gallery. Book club.,52.0,38,21,59,6,0,0,12,1,29,1,0.390099,0.335998,0.307647,0.792389,Art gallery,Book club
4,Art gallery. Bread making.,,54,21,75,6,16,0,12,1,29,1,0.27084,0.197813,0.225336,0.623973,Art gallery,Bread making


The edges of the word pairs are labelled in `word_pair`. We have the `rated_similarity` variable given by the participant with id stored in `participant`. Some similarity measures are given from different LLMs and stored under:
- `senenceBERT_mpnet_similarity`,
- `senenceBERT_miniLM_similarity`,
- `sense2vec_similarity`,
- `gptLarge_similarity`.

For the edges, added to these similarity measures, we have the `abs_liking_difference` computed from `word1_liking` and `word2_liking`. The experience for the two words are stored in `word1_experience` and `word2_experience`. 

Depending directly on the participant, we have his/her `age`, gender under `female` (1 for female), and `depression` and `depressionCont` scores. 

# Graph

In [41]:
from sklearn.preprocessing import StandardScaler

def nor_function(a,b):
    return (a or b) and not(a and b)

data["NoExp_Exp"] = data.apply(lambda row: nor_function(row["word1_experience"]>50,row["word2_experience"]>50),axis=1)

scaler = StandardScaler()
data.loc[:,["word1_sc_liking","word2_sc_liking","sc_senenceBERT_mpnet_similarity","sc_depressionCont","sc_NoExp_Exp"]] = scaler.fit_transform(data.loc[:,["word1_liking","word2_liking","senenceBERT_mpnet_similarity","depressionCont","NoExp_Exp"]])
data.loc[:,["word1_sc_liking","word2_sc_liking","sc_senenceBERT_mpnet_similarity","sc_depressionCont","sc_NoExp_Exp"]]

Unnamed: 0,word1_sc_liking,word2_sc_liking,sc_senenceBERT_mpnet_similarity,sc_depressionCont,sc_NoExp_Exp
0,-1.172222,-1.152427,1.118413,-0.229448,-0.891164
1,-1.172222,0.605846,-0.021906,-0.229448,-0.891164
2,-1.172222,0.546244,0.867681,-0.229448,1.122128
3,-1.172222,-0.019980,1.244298,-0.229448,-0.891164
4,-1.172222,0.456840,0.193088,-0.229448,-0.891164
...,...,...,...,...,...
198235,0.456083,0.546244,0.040509,-0.708297,1.122128
198236,0.456083,1.201871,-0.668605,-0.708297,1.122128
198237,0.394638,0.546244,-1.800021,-0.708297,1.122128
198238,0.394638,1.201871,-0.165669,-0.708297,1.122128


In [43]:
subdata = data[data["participant"] == 1]


participant_graph = convert_table_to_graph(
    complete_data_table=subdata,
    node_attr_names=["sc_liking"],
    node_label_names=["sc_liking"],
    edge_attr_names=["senenceBERT_mpnet_similarity"])

participant_graph.x

Test function convert_table_to_graph
Data(x=[60, 1], edge_index=[2, 3540], edge_attr=[3540, 1], y=[60, 1], train_mask=[60], val_mask=[60])
validate: True
is undirected: True
has_self_loop: tensor(False)
end Test function convert_table_to_graph


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  extracted_features_word1.rename(columns=col_renaming,inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  extracted_features_word2.rename(columns=col_renaming,inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  complete_data_table["word1_index"] = complete_data_table["word1"].apply(lambda single_word: translater_word_to_index[single_word])
A value is trying to be set on a copy of a slice from 

tensor([[-1.1722],
        [-1.1722],
        [ 0.6404],
        [ 0.5790],
        [-0.0048],
        [ 0.4868],
        [ 0.8862],
        [ 0.9476],
        [ 0.8862],
        [ 0.7940],
        [-1.6023],
        [-0.9264],
        [ 0.5483],
        [-0.1891],
        [ 0.9169],
        [-0.1891],
        [ 0.5483],
        [-0.6499],
        [-0.2813],
        [ 0.8555],
        [ 0.7326],
        [ 0.7633],
        [ 1.2549],
        [ 1.2549],
        [ 1.0705],
        [-0.2813],
        [-1.2029],
        [ 1.0705],
        [-1.4180],
        [ 0.9476],
        [-1.2951],
        [ 0.5483],
        [-0.1891],
        [ 0.7326],
        [ 0.5175],
        [ 1.2549],
        [ 0.7326],
        [-0.7421],
        [-0.0969],
        [-1.0801],
        [-1.1415],
        [ 0.9476],
        [-0.9264],
        [ 0.6097],
        [ 1.2549],
        [-1.6331],
        [ 0.0874],
        [-0.4042],
        [-1.2951],
        [ 1.0091],
        [-0.4042],
        [-0.5885],
        [-1.

In [None]:
my_module = torch_geometric.nn.GATConv(
    in_channels=(1,1),
    out_channels=1,
    heads=1,
    negative_slope=1.0,
    add_self_loops=False,
    edge_dim=1)
# my_module(x=participant_graph.x,edge_index=participant_graph.edge_index,edge_attr=participant_graph.edge_attr)

complete_model = GNN_naive_framework(my_module,device)
opt = complete_model.configure_optimizer(lr=1)
scheduler = complete_model.configure_scheduler(opt,1,1,10)

participant_graph_batch_0 = Data(x=participant_graph.x[:30],
                                 y=participant_graph.x[:30], 
                                 train_mask = torch.ones(30), 
                                 edge_index=torch.Tensor([[]]),
                                 edge_attr=torch.Tensor([[]]))

participant_graph_batch_1 = Data(x=participant_graph.x[30:],
                                 y=participant_graph.x[30:], 
                                 train_mask = torch.ones(30), 
                                 edge_index=torch.Tensor([[]]),
                                 edge_attr=torch.Tensor([[]]))

complete_model.train([participant_graph_batch_0,participant_graph_batch_1],10000,1,opt,scheduler,"train_loss",100)

In [None]:
class MLPModel(torch.nn.Module):
    def __init__(self, c_in, c_hidden, c_out, num_layers=2, dp_rate=0.1):
        """MLPModel.

        Args:
            c_in: Dimension of input features
            c_hidden: Dimension of hidden features
            c_out: Dimension of the output features. Usually number of classes in classification
            num_layers: Number of hidden layers
            dp_rate: Dropout rate to apply throughout the network

        """
        super().__init__()
        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers - 1):
            layers += [torch.nn.Linear(in_channels, out_channels), torch.nn.Sigmoid(), torch.nn.Dropout(dp_rate)]
            in_channels = c_hidden
        layers += [torch.nn.Linear(in_channels, c_out)]
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x, *args, **kwargs):
        """Forward.

        Args:
            x: Input features per node

        """
        return self.layers(x)

my_module = MLPModel(c_in=1, c_hidden=1, c_out=1,num_layers=2,dp_rate=0.0)

print([param for param in my_module.parameters()])

complete_model = GNN_naive_framework(my_module,device)
opt = complete_model.configure_optimizer(lr=1)
scheduler = complete_model.configure_scheduler(opt,1,1,10)

history = complete_model.train([participant_graph],10000,1,opt,scheduler,"train_loss",100)

[Parameter containing:
tensor([[-0.8243]], requires_grad=True), Parameter containing:
tensor([-0.2371], requires_grad=True), Parameter containing:
tensor([[0.5436]], requires_grad=True), Parameter containing:
tensor([0.8273], requires_grad=True)]
== start training ==
batch_graph: Data(x=[60, 1], edge_index=[2, 3540], edge_attr=[3540, 1], y=[60, 1], train_mask=[60], val_mask=[60])
preds: tensor([[1.1941],
        [1.1941],
        [1.0000],
        [1.0060],
        [1.0676],
        [1.0152],
        [0.9770],
        [0.9716],
        [0.9770],
        [0.9854],
        [1.2336],
        [1.1691],
        [1.0091],
        [1.0881],
        [0.9743],
        [1.0881],
        [1.0091],
        [1.1395],
        [1.0984],
        [0.9798],
        [0.9912],
        [0.9883],
        [0.9464],
        [0.9464],
        [0.9611],
        [1.0984],
        [1.1971],
        [0.9611],
        [1.2174],
        [0.9716],
        [1.2060],
        [1.0091],
        [1.0881],
        [0.9912]

In [73]:
[param for param in my_module.parameters()]

[Parameter containing:
 tensor([[ 2.5801],
         [-1.2462],
         [-1.2404],
         [-0.3156],
         [-3.9072]], device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([-8.2896,  0.5276, -9.6809, -8.8769, -4.8124], device='cuda:0',
        requires_grad=True),
 Parameter containing:
 tensor([[ 3.5628, -3.3370,  2.6040,  5.0165, -0.7148]], device='cuda:0',
        requires_grad=True),
 Parameter containing:
 tensor([2.0711], device='cuda:0', requires_grad=True)]

In [74]:
my_module.forward(torch.Tensor([[[1]]]).to(device))

tensor([[[0.9899]]], device='cuda:0', grad_fn=<ViewBackward0>)

In [75]:
print(participant_graph.y)

print(participant_graph.x)

tensor([[-1.1722],
        [-1.1722],
        [ 0.6404],
        [ 0.5790],
        [-0.0048],
        [ 0.4868],
        [ 0.8862],
        [ 0.9476],
        [ 0.8862],
        [ 0.7940],
        [-1.6023],
        [-0.9264],
        [ 0.5483],
        [-0.1891],
        [ 0.9169],
        [-0.1891],
        [ 0.5483],
        [-0.6499],
        [-0.2813],
        [ 0.8555],
        [ 0.7326],
        [ 0.7633],
        [ 1.2549],
        [ 1.2549],
        [ 1.0705],
        [-0.2813],
        [-1.2029],
        [ 1.0705],
        [-1.4180],
        [ 0.9476],
        [-1.2951],
        [ 0.5483],
        [-0.1891],
        [ 0.7326],
        [ 0.5175],
        [ 1.2549],
        [ 0.7326],
        [-0.7421],
        [-0.0969],
        [-1.0801],
        [-1.1415],
        [ 0.9476],
        [-0.9264],
        [ 0.6097],
        [ 1.2549],
        [-1.6331],
        [ 0.0874],
        [-0.4042],
        [-1.2951],
        [ 1.0091],
        [-0.4042],
        [-0.5885],
        [-1.

In [76]:
import plotly.graph_objects as go

def plot_errors_labels_comparison(model:GNN_naive_framework,graph:torch_geometric.data.Data,plot_attention_weights=False):
    if plot_attention_weights:
        preds, (adj, alpha) = model.predict(graph.x,
                                    graph.edge_index,
                                    graph.edge_attr,
                                    return_attention_weights=True)
    else:
        preds = model.predict(graph.x,
                                    graph.edge_index,
                                    graph.edge_attr,
                                    return_attention_weights=False)
        
    preds = np.array(preds.detach().to("cpu"))
    preds = np.squeeze(preds)
    labels = np.array(participant_graph.y)
    labels = np.squeeze(labels)

    errors = labels-preds


    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x = labels,
        y = errors,
        mode = "markers",
        marker=dict(color=preds)
    ))
    fig.update_layout(
        title="Residual depending on label value",
        xaxis_title="Label",
        yaxis_title="Residual"
    )
    fig.show()

    fig = go.Figure()
    fig.add_trace(go.Histogram(
        x = preds)
    )
    fig.update_layout(
        title="Residual depending on label value",
        xaxis_title="Label",
        yaxis_title="Residual"
    )
    fig.show()

    if plot_attention_weights:
        from torch_geometric.utils import (
            add_self_loops,
            is_torch_sparse_tensor,
            remove_self_loops,
            softmax,
            to_dense_adj
        )

        matrix_alpha = to_dense_adj(adj, edge_attr = alpha).cpu().detach()
        matrix_alpha = matrix_alpha.squeeze()
        fig = px.imshow(matrix_alpha)
        fig.update_layout(
            title="Alpha: the message passing strength between nodes"
        )
        fig.show()

plot_errors_labels_comparison(complete_model,participant_graph,False)

In [102]:
lin_src_params = [param for param in complete_model.update_node_module.lin_src.parameters()][0]
lin_dst_params = [param for param in complete_model.update_node_module.lin_dst.parameters()][0]

src_att_lin_params = lin_src_params * complete_model.update_node_module.att_src
dst_att_lin_params = lin_dst_params * complete_model.update_node_module.att_dst
#lin_params.squeeze()
print("param_x_src, param_x_dst =", float(src_att_lin_params), float(dst_att_lin_params))


lin_edge_params = [param for param in complete_model.update_node_module.lin_edge.parameters()][0]
att_lin_edge_params = lin_edge_params * complete_model.update_node_module.att_edge
print("params_edge", att_lin_edge_params.squeeze())

bias = [param for param in complete_model.update_node_module.bias][0]
print(bias)

param_x_src, param_x_dst = 0.6226500868797302 3384.351806640625
params_edge tensor(27.1259, device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor(50.3570, device='cuda:0', grad_fn=<UnbindBackward0>)


That param_x_dst is that small is a good sign. It means only $\alpha_{i,i}$ will keep an influence over the prediction of $x'_i$. The param

In [103]:
node_id_main = 0
edge_mask = participant_graph.edge_index[0,:] == node_id_main
edge_mask = edge_mask + participant_graph.edge_index[1,:] == node_id_main
edge_index_subgraph = participant_graph.edge_index[:,edge_mask]
edge_attr_subgraph = participant_graph.edge_attr[edge_mask]

participant_subgraph = Data(
        x = participant_graph.x, 
        edge_index = edge_index_subgraph,
        edge_attr = edge_attr_subgraph,
        y = participant_graph.x, 
        train_mask = participant_graph.train_mask, 
        val_mask = participant_graph.val_mask
        )

In [104]:
participant_subgraph.x

tensor([[ 21.],
        [ 21.],
        [ 80.],
        [ 78.],
        [ 59.],
        [ 75.],
        [ 88.],
        [ 90.],
        [ 88.],
        [ 85.],
        [  7.],
        [ 29.],
        [ 77.],
        [ 53.],
        [ 89.],
        [ 53.],
        [ 77.],
        [ 38.],
        [ 50.],
        [ 87.],
        [ 83.],
        [ 84.],
        [100.],
        [100.],
        [ 94.],
        [ 50.],
        [ 20.],
        [ 94.],
        [ 13.],
        [ 90.],
        [ 17.],
        [ 77.],
        [ 53.],
        [ 83.],
        [ 76.],
        [100.],
        [ 83.],
        [ 35.],
        [ 56.],
        [ 24.],
        [ 22.],
        [ 90.],
        [ 29.],
        [ 79.],
        [100.],
        [  6.],
        [ 62.],
        [ 46.],
        [ 17.],
        [ 92.],
        [ 46.],
        [ 40.],
        [  9.],
        [ 79.],
        [ 13.],
        [ 73.],
        [ 45.],
        [ 73.],
        [100.],
        [100.]])

In [105]:
complete_model.predict(node_attr=participant_subgraph.x,
                       edge_index=participant_subgraph.edge_index,
                       edge_attr=participant_subgraph.edge_attr)

tensor([[60.7998],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.3570],
        [50.

In [27]:
res = complete_model.update_node_module(participant_subgraph.x.to(device), participant_subgraph.edge_index.to(device), participant_subgraph.edge_attr.to(device),return_attention_weights=True)

for r in res:
    print(r)

tensor([[61.5520],
        [16.1193],
        [50.0500],
        [48.8998],
        [37.9730],
        [47.1745],
        [54.6508],
        [55.8010],
        [54.6508],
        [52.9255],
        [ 8.0679],
        [20.7201],
        [48.3247],
        [34.5224],
        [55.2259],
        [34.5224],
        [48.3247],
        [25.8959],
        [32.7971],
        [54.0757],
        [51.7753],
        [52.3504],
        [61.5520],
        [61.5520],
        [58.1014],
        [32.7971],
        [15.5442],
        [58.1014],
        [11.5185],
        [55.8010],
        [13.8189],
        [48.3247],
        [34.5224],
        [51.7753],
        [47.7496],
        [61.5520],
        [51.7753],
        [24.1707],
        [36.2477],
        [17.8446],
        [16.6944],
        [55.8010],
        [20.7201],
        [49.4749],
        [61.5520],
        [ 7.4928],
        [39.6983],
        [30.4967],
        [13.8189],
        [56.9512],
        [30.4967],
        [27.0461],
        [ 9.

Les derniers qui nous intéressent sont bien à 1 tandis que les autres sont à 0.

In [113]:
from examples.introduction.my_GAT_implem import myGATConv

my_module = myGATConv(
    in_channels=(1,1),
    out_channels=1,
    heads=1,
    negative_slope=0.2,
    add_self_loops=False,
    edge_dim=1)
# my_module(x=participant_graph.x,edge_index=participant_graph.edge_index,edge_attr=participant_graph.edge_attr)

complete_model = GNN_naive_framework(my_module,device)
opt = complete_model.configure_optimizer(lr=1)
scheduler = complete_model.configure_scheduler(opt,1,1,10)

complete_model.train([participant_graph],10000,1,opt,scheduler,"train_loss",100)

== start training ==
epoch: 1/10000,
 train_loss: 5378.2949,
 train_mae: 59.1030,
 epoch_time_duration: 0.0056

epoch: 2/10000,
 train_loss: 4373.0112,
 train_mae: 52.4354,
 epoch_time_duration: 0.0038

epoch: 3/10000,
 train_loss: 3608.8101,
 train_mae: 46.4671,
 epoch_time_duration: 0.0031

epoch: 4/10000,
 train_loss: 2933.8242,
 train_mae: 41.5882,
 epoch_time_duration: 0.0032

epoch: 5/10000,
 train_loss: 2374.7063,
 train_mae: 37.6855,
 epoch_time_duration: 0.0029

epoch: 6/10000,
 train_loss: 1913.9542,
 train_mae: 34.3410,
 epoch_time_duration: 0.0030

epoch: 7/10000,
 train_loss: 1548.0836,
 train_mae: 31.4919,
 epoch_time_duration: 0.0033

epoch: 8/10000,
 train_loss: 1272.1437,
 train_mae: 29.1906,
 epoch_time_duration: 0.0043

epoch: 9/10000,
 train_loss: 1079.2406,
 train_mae: 27.6272,
 epoch_time_duration: 0.0031

epoch: 10/10000,
 train_loss: 960.4120,
 train_mae: 26.8701,
 epoch_time_duration: 0.0037

epoch: 11/10000,
 train_loss: 904.7333,
 train_mae: 26.4361,
 epoch_t

{'train_loss': [tensor(5378.2949),
  tensor(4373.0112),
  tensor(3608.8101),
  tensor(2933.8242),
  tensor(2374.7063),
  tensor(1913.9542),
  tensor(1548.0836),
  tensor(1272.1437),
  tensor(1079.2406),
  tensor(960.4120),
  tensor(904.7333),
  tensor(899.6686),
  tensor(931.6901),
  tensor(987.1213),
  tensor(1053.0928),
  tensor(1118.4487),
  tensor(1174.4438),
  tensor(1215.1248),
  tensor(1237.3660),
  tensor(1240.6007),
  tensor(1226.3479),
  tensor(1197.6333),
  tensor(1158.4042),
  tensor(1112.9894),
  tensor(1065.6417),
  tensor(1020.1714),
  tensor(979.6805),
  tensor(946.3939),
  tensor(921.5851),
  tensor(905.5898),
  tensor(897.8991),
  tensor(897.3178),
  tensor(902.1699),
  tensor(910.5287),
  tensor(920.4459),
  tensor(930.1548),
  tensor(938.2309),
  tensor(943.6913),
  tensor(946.0344),
  tensor(945.2186),
  tensor(941.5944),
  tensor(935.8018),
  tensor(928.6495),
  tensor(920.9962),
  tensor(913.6401),
  tensor(907.2347),
  tensor(902.2318),
  tensor(898.8561),
  ten

In [110]:
import plotly.graph_objects as go

preds, (adj, alpha) = complete_model.predict(participant_graph.x,
                               participant_graph.edge_index,
                               participant_graph.edge_attr,
                               return_attention_weights=True)
preds = np.array(preds.detach().to("cpu"))
preds = np.squeeze(preds)
labels = np.array(participant_graph.y)
labels = np.squeeze(labels)

errors = np.abs(labels-preds)


fig = go.Figure()
fig.add_trace(go.Scatter(
    x = labels,
    y = errors,
    mode = "markers",
    marker=dict(color=preds)
))
fig.update_layout(
    title="Residual depending on label value",
    xaxis_title="Label",
    yaxis_title="Residual"
)
fig.show()

fig = go.Figure()
fig.add_trace(go.Histogram(
    x = preds)
)
fig.update_layout(
    title="Residual depending on label value",
    xaxis_title="Label",
    yaxis_title="Residual"
)
fig.show()

In [111]:
from torch_geometric.utils import (
    add_self_loops,
    is_torch_sparse_tensor,
    remove_self_loops,
    softmax,
    to_dense_adj
)

matrix_alpha = to_dense_adj(adj, edge_attr = alpha).cpu().detach()
matrix_alpha = matrix_alpha.squeeze()
fig = px.imshow(matrix_alpha)
fig.update_layout(
    title="Alpha: the message passing strength between nodes"
)
fig.show()

In [112]:
lin_src_params = [param for param in complete_model.update_node_module.lin_src.parameters()][0]
lin_dst_params = [param for param in complete_model.update_node_module.lin_dst.parameters()][0]

src_att_lin_params = lin_src_params * complete_model.update_node_module.att_src
dst_att_lin_params = lin_dst_params * complete_model.update_node_module.att_dst
#lin_params.squeeze()
print("param_x_src, param_x_dst =", float(src_att_lin_params), float(dst_att_lin_params))


lin_content_src_params = [param for param in complete_model.update_node_module.lin_src.parameters()][0]
lin_content_dst_params = [param for param in complete_model.update_node_module.lin_dst.parameters()][0]

#lin_params.squeeze()
print("lin_content_src_params, lin_content_src_params =", float(lin_content_src_params), float(lin_content_dst_params))

lin_edge_params = [param for param in complete_model.update_node_module.lin_edge.parameters()][0]
att_lin_edge_params = lin_edge_params * complete_model.update_node_module.att_edge
print("params_edge", att_lin_edge_params.squeeze())

bias = [param for param in complete_model.update_node_module.bias][0]
print(bias)

param_x_src, param_x_dst = -42.656185150146484 5.747586250305176
lin_content_src_params, lin_content_src_params = 6.158009052276611 -1.351555585861206
params_edge tensor(-26.7681, device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor(399.3427, device='cuda:0', grad_fn=<UnbindBackward0>)


In [49]:
participant_graph.x.shape

torch.Size([60, 1])

In [None]:
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(
    participant_graph,
    # Sample 30 neighbors for each node for 2 iterations
    num_neighbors=[60],
    # Use a batch size of 128 for sampling training nodes
    batch_size=1,
    input_nodes=participant_graph.train_mask,
)

sampled_data = next(iter(loader))
print(sampled_data.batch_size)

### Parameters

In [None]:
for param in complete_model.update_node_module.parameters():
    print(param)

print("1")
print([param for param in complete_model.update_node_module.lin.parameters()])

print("att_src",complete_model.update_node_module.att_src)
print("att_dst",complete_model.update_node_module.att_dst)
print("lin",[param for param in complete_model.update_node_module.lin.parameters()])
print("lin_edge",[param for param in complete_model.update_node_module.lin_edge.parameters()])
print("att_edge",complete_model.update_node_module.att_edge[0])
print("bias",complete_model.update_node_module.bias)

lin_params = [param for param in complete_model.update_node_module.lin.parameters()][0]
src_att_lin_params = lin_params * complete_model.update_node_module.att_src
dst_att_lin_params = lin_params * complete_model.update_node_module.att_dst
#lin_params.squeeze()
print("param_x_src, param_x_dst =", float(src_att_lin_params), float(dst_att_lin_params))


lin_edge_params = [param for param in complete_model.update_node_module.lin_edge.parameters()][0]
att_lin_edge_params = lin_edge_params * complete_model.update_node_module.att_edge
print("params_edge", att_lin_edge_params.squeeze())