In [1]:
import torch
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from sklearn.decomposition import PCA
import seaborn as sns
import matplotlib.pyplot as plt

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("device:", device)

device: mps


In [2]:
# Preprocess data

edge = torch.load("./Cora/edge.pt")
edge = torch.transpose(edge, 0, 1)

feat = torch.load("./Cora/feat.pt").type(torch.float32)

label = torch.load("./Cora/label.pt")

degree = torch.from_numpy(np.bincount(edge[:, 0]))

feat = feat.to(device)
degree = degree.to(device)
edge = edge.to(device)


print(f"# Nodes: {feat.shape[0]}")
print(f"# Edges: {edge.shape[0]}")

num_nodes, dim_feat = feat.shape
num_label = max(label) + 1
dim_hidden = 128

adj = torch.zeros((num_nodes, num_nodes), device=device)
print(edge)
for i, j in edge:
    adj[i, j] += 1

# Nodes: 2708
# Edges: 10556
tensor([[   0,  633],
        [   0, 1862],
        [   0, 2582],
        ...,
        [2707,  598],
        [2707, 1473],
        [2707, 2706]], device='mps:0')


In [3]:
class GraphSageLayer(nn.Module):

    def __init__(self, dim_in: int, dim_out: int):
        super(GraphSageLayer, self).__init__()

        self.dim_in = dim_in
        self.dim_out = dim_out
        self.act = nn.ReLU()

        self.weight = nn.Parameter(data=torch.zeros((dim_in, dim_out), device=device), requires_grad=False)
        nn.init.xavier_uniform_(self.weight)

    def forward(self, feat: torch.Tensor,
                adjacency: torch.Tensor,
                activate: bool) -> torch.Tensor:
        
        adj_eye = adjacency + torch.eye(num_nodes, device=device)
        degree = torch.sum(adj_eye, dim=1)
        # diag_inv_sqrt = torch.diag(torch.sqrt(1/degree))
        dgr_inv = torch.diag(1/degree)
        
        laplace_sym = dgr_inv @ adj_eye #@ diag_inv_sqrt
        out = laplace_sym @ feat @ self.weight
        if activate:
            out = self.act(out)
        
        return out
    
class GraphSage(nn.Module):

    def __init__(self, num_layers: int,
                 dim_in: int,
                 dim_hidden: int,
                 dim_out: int):

        super(GraphSage, self).__init__()

        self.num_layers = num_layers
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.dim_out = dim_out

        self.layers = nn.ModuleList()
        for l in range(num_layers):
            self.layers.append(GraphSageLayer(self.dim_in if l == 0 else self.dim_hidden, self.dim_hidden if l < num_layers-1 else self.dim_out))


    def forward(self, feat: torch.Tensor,
                adj: torch.Tensor) -> torch.Tensor:

        x_in = feat
        for layer in self.layers:
            x_out = layer(x_in, adj, (False if self.layers[-1] == layer else True))
            x_in = x_out

        return x_out

In [37]:
LAYERS = list(range(1, 11))

dim_in = feat.shape[1]
dim_hidden = 700
dim_out = 100
SEED = 42

torch.manual_seed(SEED)
model = GraphSage(max(LAYERS), dim_in, dim_hidden, dim_out)
model.eval()

for num_layer in tqdm(LAYERS):

    feature_matrix = model(feat, adj)
    feature_matrix = feature_matrix.detach().cpu().numpy()
    
    pca= PCA(n_components=1)
    embedding = pca.fit_transform(feature_matrix)
    
    tmp = []
    for l in range(0, 7):
        class_embedding = embedding[label == l]
        print(num_layer, l)
        print(np.mean(class_embedding, axis=0))
        print(np.var(class_embedding, axis=0))
    

 10%|█         | 1/10 [00:00<00:01,  5.09it/s]

1 0
[-0.3035706]
[0.04478689]
1 1
[-0.28699145]
[0.05476156]
1 2
[0.19857413]
[0.10594685]
1 3
[0.24167086]
[0.09691481]
1 4
[0.11482083]
[0.05524124]
1 5
[-0.33257157]
[0.03470299]
1 6
[-0.34259793]
[0.05689025]
[-0.10152368]


 20%|██        | 2/10 [00:00<00:01,  4.02it/s]

2 0
[-0.09143926]
[0.00891152]
2 1
[0.05337184]
[0.0083474]
2 2
[0.0853442]
[0.01635472]
2 3
[0.10756865]
[0.03599006]
2 4
[-0.03925185]
[0.02028248]
2 5
[-0.21562187]
[0.01723618]
2 6
[-0.12319505]
[0.00563867]
[-0.03188905]


 30%|███       | 3/10 [00:00<00:02,  3.12it/s]

3 0
[-0.05228771]
[0.00664855]
3 1
[-0.10136565]
[0.00793473]
3 2
[-0.00506045]
[0.00754746]
3 3
[0.0149453]
[0.0149472]
3 4
[0.08894083]
[0.0177511]
3 5
[-0.00633751]
[0.00988528]
3 6
[-0.03200487]
[0.01112059]
[-0.01331001]


 40%|████      | 4/10 [00:01<00:02,  2.74it/s]

4 0
[-0.01136504]
[0.00251704]
4 1
[-0.02550538]
[0.00388542]
4 2
[-0.04681129]
[0.00383489]
4 3
[0.01473593]
[0.00799393]
4 4
[0.04350341]
[0.00795863]
4 5
[-0.01457595]
[0.00469774]
4 6
[0.01582268]
[0.00718148]
[-0.00345652]


 50%|█████     | 5/10 [00:01<00:01,  2.53it/s]

5 0
[-0.01441616]
[0.00241273]
5 1
[-0.04059884]
[0.00392766]
5 2
[-0.02904196]
[0.00269421]
5 3
[0.01954638]
[0.00621024]
5 4
[0.01225599]
[0.00519688]
5 5
[0.00591148]
[0.00293726]
5 6
[0.01687753]
[0.00522975]
[-0.00420937]


 60%|██████    | 6/10 [00:02<00:01,  2.36it/s]

6 0
[-0.00484653]
[0.00117377]
6 1
[-0.00912331]
[0.00134036]
6 2
[-0.01220683]
[0.00122678]
6 3
[0.00219186]
[0.00308836]
6 4
[0.00759378]
[0.00307685]
6 5
[-0.00166058]
[0.00146779]
6 6
[0.02361263]
[0.00284874]
[0.00079443]


 70%|███████   | 7/10 [00:02<00:01,  2.29it/s]

7 0
[-0.01318838]
[0.00059968]
7 1
[-0.02167068]
[0.00096998]
7 2
[-0.01475505]
[0.00075316]
7 3
[0.00978401]
[0.00190389]
7 4
[0.01635067]
[0.00182624]
7 5
[-0.00277606]
[0.00105942]
7 6
[0.00754389]
[0.00208766]
[-0.00267309]


 80%|████████  | 8/10 [00:03<00:00,  2.26it/s]

8 0
[-0.00510638]
[0.00013819]
8 1
[-0.01163035]
[0.0003318]
8 2
[-0.00602678]
[0.0002953]
8 3
[0.00137361]
[0.00072001]
8 4
[0.00691672]
[0.00067691]
8 5
[0.0030935]
[0.00041004]
8 6
[0.01024065]
[0.00062301]
[-0.00016272]


 90%|█████████ | 9/10 [00:03<00:00,  2.09it/s]

9 0
[-0.00401005]
[9.8971446e-05]
9 1
[-0.00590047]
[9.497875e-05]
9 2
[-0.00264975]
[0.0001177]
9 3
[0.00139007]
[0.00032522]
9 4
[0.00298975]
[0.00024643]
9 5
[0.00132903]
[0.00014815]
9 6
[0.00549301]
[0.00030914]
[-0.00019406]


100%|██████████| 10/10 [00:04<00:00,  2.34it/s]

10 0
[-0.00220952]
[7.005722e-05]
10 1
[-0.006438]
[9.736479e-05]
10 2
[-0.00677567]
[7.313288e-05]
10 3
[0.00263432]
[0.00024422]
10 4
[0.00546122]
[0.00016247]
10 5
[0.00012778]
[8.812154e-05]
10 6
[0.00269665]
[0.00016473]
[-0.00064332]





In [30]:
len(embedding[label==0])

351