In [None]:
import torch_geometric as pyg

In [None]:
import os, torch
from sklearn.model_selection import train_test_split
import pickle
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.nn.models import Node2Vec
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

epochs = 100000          # NOTE(review): the fine-tuning cell passes 50000 explicitly and does not read this value
learning_rate = 0.0001
dropout_p = 0
hidden_c = 256           # width of the first GCN layer
num_layers = 6           # number of width-halving GCN layers; embedding width = hidden_c // 2**num_layers
seed = 42

# Seed torch's CPU and CUDA generators so weight initialisation and dropout
# masks are reproducible across runs. NOTE: GPU runs may still not be
# bit-for-bit identical if non-deterministic CUDA kernels are used.
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}", flush=True)
else:
    device = torch.device('cpu')
    print("Using CPU", flush=True)

### load graph data

# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# this graph if it was produced by a trusted pipeline.
with open('../data/graphs/2/linegraph_tg.pkl', 'rb') as f:
    data = pickle.load(f)

data.edge_index = data.edge_index.contiguous()
data.x = data.x.contiguous()
# The autoencoder never reads `y`; guard in case the graph carries no labels.
if getattr(data, 'y', None) is not None:
    data.y = data.y.contiguous()

# Standardise every node-feature column to zero mean / unit variance. The
# fitted scaler stays in `sc` (e.g. for `sc.inverse_transform`). StandardScaler
# returns float64 for float64 input, while the model's weights are float32, so
# cast explicitly to avoid a dtype mismatch inside the first GCN layer.
sc = StandardScaler()
data.x = torch.tensor(sc.fit_transform(data.x.cpu().numpy()), dtype=torch.float32)


class GAE(torch.nn.Module):
    """Graph autoencoder computing ``decoder(encoder(x, edge_index))``.

    Args:
        encoder: callable/module taking ``(x, edge_index)`` and returning node embeddings.
        decoder: callable/module taking the embeddings alone.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, edge_index):
        """Encode ``x`` over the graph given by ``edge_index``, then decode it."""
        return self.decoder(self.encoder(x, edge_index))

class encoder(torch.nn.Module):
    """GCN encoder mapping node features to a narrow node embedding.

    Architecture: one ``GCNConv(in_channels -> hidden_channels)`` followed by
    ``num_layers`` ``GCNConv`` layers that each halve the width. The output
    width is therefore ``hidden_channels // 2 ** num_layers`` (stored as
    ``self.embedding_dim``). Every layer, the last included, is followed by
    ReLU and dropout, so the returned embedding is non-negative.

    All layers use ``cached=True``. PyG then reuses the normalised adjacency
    computed on the first call, so an instance should only be applied to a
    single graph.

    Args:
        in_channels: number of input node features.
        out_channels: unused; kept for interface compatibility.
        hidden_channels: width of the first GCN layer.
        num_layers: number of width-halving GCN layers after the first one.
        dropout_p: dropout probability applied after every layer in training mode.

    Raises:
        ValueError: if halving ``num_layers`` times would leave fewer than one channel.
    """

    # Class-level fallback: instances unpickled from checkpoints saved before
    # `dropout_p` became an instance attribute still resolve it (the notebook used 0).
    dropout_p = 0

    def __init__(self, in_channels, out_channels, hidden_channels, num_layers, dropout_p):
        super(encoder, self).__init__()
        embedding_dim = hidden_channels // 2 ** num_layers
        if embedding_dim < 1:
            raise ValueError(
                f"hidden_channels={hidden_channels} is too small for num_layers={num_layers}: "
                f"halving {num_layers} times leaves {embedding_dim} channels"
            )
        self.embedding_dim = embedding_dim
        # Keep the configured dropout on the instance instead of reading a notebook global.
        self.dropout_p = dropout_p
        # Use the constructor argument rather than the notebook-global `data`,
        # so the encoder is correct for any input width.
        self.conv1 = GCNConv(in_channels, hidden_channels, improved=True, cached=True)
        conv2_list = []
        hc = hidden_channels
        for _ in range(num_layers):
            conv2_list.append(GCNConv(hc, hc // 2, improved=True, cached=True))
            hc //= 2
        self.conv2 = torch.nn.ModuleList(conv2_list)

    def forward(self, x, edge_index):
        """Return node embeddings with ``self.embedding_dim`` channels.

        Assumes ``x`` is ``[num_nodes, in_channels]`` and ``edge_index`` is the
        PyG ``[2, num_edges]`` connectivity (PyG GCNConv convention).
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        for conv in self.conv2:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_p, training=self.training)
        return x

class decoder(torch.nn.Module):
    """Single affine layer mapping embeddings back to the node-feature space.

    Args:
        in_channels: embedding width.
        out_channels: number of reconstructed features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute name `lin` is part of the state_dict keys; keep it stable.
        self.lin = Linear(in_channels, out_channels)

    def forward(self, x):
        """Return ``self.lin(x)``."""
        return self.lin(x)

# The encoder halves its width `num_layers` times after its first layer, so its
# output width is hidden_c // 2**num_layers (4 with the defaults above).
# Derive the decoder's input size instead of hard-coding it, so changing
# hidden_c or num_layers keeps encoder and decoder in sync.
latent_c = hidden_c // 2 ** num_layers
model = GAE(encoder(data.num_features, hidden_c, hidden_c, num_layers, dropout_p),
            decoder(latent_c, data.num_features)).to(device)


In [None]:
print(model, flush=True)

# Make sure the checkpoint directory exists so saving after training cannot fail on it.
os.makedirs("../data/graphs/2/models", exist_ok=True)
# Only write the untrained model when no checkpoint exists yet. An
# unconditional save here would overwrite a previously trained
# test_model.pt with freshly initialised weights on every re-run.
checkpoint_path = "../data/graphs/2/models/test_model.pt"
if not os.path.exists(checkpoint_path):
    torch.save(model, checkpoint_path)

In [None]:

# Put the tensors the model consumes on the same device as its parameters
# (data.y is not used for training and is left where it is).
data.edge_index = data.edge_index.to(device)
data.x = data.x.to(device)

# Reconstruction objective: mean squared error between decoded and input features.
criterion = torch.nn.MSELoss().to(device)
# Adam with L2 weight decay over all model parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=5e-4,
)


def train(data):
    """Run one full-batch optimisation step on the notebook-level `model`.

    Uses the global `model`, `optimizer` and `criterion`. The target is the
    input feature matrix itself (reconstruction loss).

    Returns:
        The loss for this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    reconstruction = model(data.x, data.edge_index)
    step_loss = criterion(reconstruction, data.x)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def test(data):
    """Evaluate the reconstruction loss of the notebook-level `model` without updating it.

    Switches `model` to eval mode (dropout off) and runs under `torch.no_grad()`.
    The model is left in eval mode afterwards.

    Returns:
        `criterion` between the reconstruction and `data.x`, as a Python float.
    """
    model.eval()
    with torch.no_grad():
        eval_loss = criterion(model(data.x, data.edge_index), data.x)
    return eval_loss.item()

def train_model(data, epochs, log_every=1000):
    """Train the notebook-level `model` for `epochs` steps and plot the loss curve.

    Args:
        data: graph passed to `train` on every step.
        epochs: number of full-batch optimisation steps.
        log_every: print the loss every this many epochs, starting at epoch 0;
            0 or None disables logging.

    Returns:
        List with the training loss of every epoch.
    """
    losses = []
    for epoch in range(epochs):
        loss = train(data)
        losses.append(loss)
        if log_every and epoch % log_every == 0:
            print(f'Epoch {epoch}, Loss: {loss:.9f}', flush=True)

    # Plot the collected losses; a log y-axis keeps late-training
    # improvements visible next to the large initial loss.
    fig, ax = plt.subplots()
    ax.plot(range(len(losses)), losses)
    ax.set(title="Training loss", xlabel="Epoch", ylabel="Loss")
    ax.set_yscale("log")
    return losses



In [None]:
# Continue training with a 10x smaller learning rate. A fresh Adam instance is
# created, so its moment estimates restart from zero.
learning_rate = 1e-5
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=5e-4,
)
train_model(data, epochs=50000)
torch.save(model, "../data/graphs/2/models/test_model.pt")