In [None]:
import os, torch
from sklearn.model_selection import train_test_split
import pickle
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.nn.models import Node2Vec
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

# Run configuration. Every value can be overridden through an environment variable.
epochs = int(os.getenv("EPOCHS", 1000))  # Default to 1000 if not provided
learning_rate = float(os.getenv("LEARNING_RATE", 0.001))  # Default to 0.001
hidden_c = int(os.getenv("HIDDEN_C", 32))  # Default to 32
random_seed = int(os.getenv("RANDOM_SEED", 42))  # Default to 42
api_key = os.getenv("API_KEY", None)  # Not referenced by the visible cells
graph_num = os.getenv("GRAPH_NUM", 12)  # str when set via env, int 12 otherwise; only interpolated into a path
dropout_p = float(os.getenv("DROPOUT", 0.5))  # Default to 0.5

# Seed torch's global RNG so weight initialisation, dropout masks and loader
# shuffling are repeatable for a given RANDOM_SEED.
torch.manual_seed(random_seed)

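# Optional Weights & Biases logging (currently disabled).
# NOTE: `num_layers` below is not defined anywhere in the visible cells; define it
# or drop it from the config (and import wandb) before re-enabling this block.
# wandb.login()
# run = wandb.init(
#     project="graph-embedding",
#     config={
#         "epochs": epochs,
#         "learning_rate": learning_rate,
#         "hidden_c": hidden_c,
#         "random_seed": random_seed,
#         "num_layers": num_layers,
#         "dropout_p": dropout_p
#     }
# )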

# Pick the compute device once: the first CUDA GPU when present, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}", flush = True)
else:
    print("Using CPU", flush = True)

### load graph data

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load graphs produced by a trusted pipeline.
with open(f'../data/graphs/{graph_num}/linegraph_tg.pkl', 'rb') as f:
    data = pickle.load(f)

data.edge_index = data.edge_index.contiguous()
data.x = data.x.contiguous()
data.y = data.y.contiguous()

# Standardise every feature column to zero mean / unit variance.
# NOTE(review): the scaler is fitted on all rows before the train/test split
# below, so held-out rows influence the scaling statistics.
sc = StandardScaler()
# StandardScaler returns float64 for anything that is not already float32;
# force float32 so the features match the dtype of the Linear layers.
data.x = torch.tensor(sc.fit_transform(data.x.cpu().numpy()), dtype=torch.float32)


In [None]:
# Hold out 20% of the feature rows for evaluating the autoencoder, then wrap
# both parts in mini-batch loaders (only the training rows are shuffled).
train_x, test_x = train_test_split(data.x, test_size=0.2, random_state=random_seed)
train_loader = DataLoader(train_x, batch_size=64, shuffle=True)
test_loader = DataLoader(test_x, batch_size=64, shuffle=False)

class my_autoencoder(torch.nn.Module):
    """Fully connected autoencoder with a mirrored encoder and decoder.

    The encoder maps ``in_channels -> hidden_channels -> hidden_channels // 2
    -> hidden_channels // 4``; the decoder walks the same widths back up to
    ``in_channels``. Every Linear layer except the last one of each half is
    followed by ReLU and Dropout(dropout_p).
    """

    def __init__(self, in_channels, hidden_channels, dropout_p):
        super().__init__()
        half = hidden_channels // 2
        quarter = hidden_channels // 4
        self.encoder = self._mlp([in_channels, hidden_channels, half, quarter], dropout_p)
        self.decoder = self._mlp([quarter, half, hidden_channels, in_channels], dropout_p)

    @staticmethod
    def _mlp(widths, dropout_p):
        """Stack Linear layers over consecutive widths, with ReLU + Dropout between them."""
        layers = []
        last = len(widths) - 2
        for i, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(torch.nn.Linear(fan_in, fan_out))
            if i < last:
                layers.append(torch.nn.ReLU())
                layers.append(torch.nn.Dropout(dropout_p))
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back; returns the decoder output."""
        return self.decoder(self.encoder(x))

# Instantiate the autoencoder with one input unit per feature column and
# train it against a mean-squared reconstruction error.
n_features = data.x.shape[1]
model = my_autoencoder(n_features, hidden_c, dropout_p)
model = model.to(device)
print(f"Model: {model}", flush = True)
criterion = torch.nn.MSELoss()


In [None]:
# Adam over all autoencoder weights; the learning rate is reduced when the
# metric passed to scheduler.step() has not improved for 10 consecutive calls.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=10,
)

def train():
    """Run one optimisation pass of the autoencoder over ``train_loader``.

    Each batch is reconstructed from itself and the model is updated after
    every batch.

    Returns:
        The mean of the per-batch losses (sum of batch losses divided by the
        number of batches).
    """
    model.train()
    running = 0
    for features in train_loader:
        features = features.to(device)
        optimizer.zero_grad()
        reconstruction = model(features)
        batch_loss = criterion(reconstruction, features)
        batch_loss.backward()
        optimizer.step()
        running += batch_loss.item()
    n_batches = len(train_loader)
    return running / n_batches

def test(loader):
    """Evaluate the autoencoder's reconstruction loss on ``loader``.

    The model is switched to eval mode and gradients are disabled.

    Args:
        loader: iterable of feature batches that also supports ``len()``.

    Returns:
        The mean of the per-batch losses.
    """
    model.eval()
    running = 0
    with torch.no_grad():
        for features in loader:
            features = features.to(device)
            reconstruction = model(features)
            running += criterion(reconstruction, features).item()
    n_batches = len(loader)
    return running / n_batches

# Pre-train the autoencoder. The plateau scheduler is driven by the training loss.
for epoch in range(1, epochs + 1):
    train_loss = train()
    test_loss = test(test_loader)
    scheduler.step(train_loss)
    # Read the learning rate(s) straight from the optimizer: ReduceLROnPlateau
    # has no get_last_lr() on older PyTorch releases, and this list is exactly
    # what get_last_lr() reports where it exists.
    current_lr = [group['lr'] for group in optimizer.param_groups]
    print(f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss:.4f} current_lr : {current_lr}', flush = True)
print('Training complete', flush = True)


In [None]:
def stratified_split(data, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Attach train/val/test node masks to ``data``, stratified by ``y > 0``.

    Nodes with ``y > 0`` and nodes with ``y <= 0`` are split separately with
    the same ratios, so both groups appear in every subset. ``train_ratio`` of
    each group goes to training; the remainder is divided between validation
    and test in the proportion ``val_ratio : test_ratio``.

    Args:
        data: graph object exposing ``y`` (one label per node, shaped ``[N]``
            or ``[N, 1]``) and ``num_nodes``.
        train_ratio: fraction of each group assigned to the training mask.
        val_ratio: relative weight of the validation part of the remainder.
        test_ratio: relative weight of the test part of the remainder.

    Returns:
        The same ``data`` object, with boolean ``train_mask``, ``val_mask``
        and ``test_mask`` attributes of length ``data.num_nodes`` set on it.

    Raises:
        ValueError: if ``y`` does not hold exactly one value per node, or
            (from ``train_test_split``) if a group is too small to be split.
    """
    # Flatten to one label per node. On a [N, 1] tensor nonzero() returns
    # (row, col) pairs, and the column index 0 would then be treated as a node
    # index, leaking node 0 into all three masks.
    labels = data.y.detach().reshape(-1).cpu()
    if labels.numel() != data.num_nodes:
        raise ValueError(
            f"expected one label per node, got {labels.numel()} labels for {data.num_nodes} nodes"
        )

    positive_mask = labels > 0
    # reshape(-1) (not squeeze) keeps the index tensors 1-D even when a group
    # contains a single node.
    positive_indices = positive_mask.nonzero(as_tuple=False).reshape(-1)
    negative_indices = (~positive_mask).nonzero(as_tuple=False).reshape(-1)

    # Share of the non-training remainder that goes to the test set.
    holdout_test_fraction = test_ratio / (val_ratio + test_ratio)

    train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)

    for group_indices in (positive_indices, negative_indices):
        if group_indices.numel() == 0:
            # Nothing to stratify for this group (e.g. every node is positive).
            continue
        train_idx, temp_idx = train_test_split(
            group_indices, train_size=train_ratio, random_state=random_seed
        )
        val_idx, test_idx = train_test_split(
            temp_idx, test_size=holdout_test_fraction, random_state=random_seed
        )
        train_mask[train_idx] = True
        val_mask[val_idx] = True
        test_mask[test_idx] = True

    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data

# Attach stratified 70/15/15 train/val/test node masks to the graph object.
split_ratios = dict(train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)
data = stratified_split(data, **split_ratios)


In [None]:
# Bin edges used to bucket y into classes. BINS is a whitespace-separated list
# of integers; the default is the single edge 3000 (i.e. two classes).
# split() without an argument tolerates repeated/trailing whitespace, and the
# edges are de-duplicated and sorted because torch.bucketize expects strictly
# increasing boundaries.
bins = sorted({int(i) for i in os.getenv("BINS", "3000").split()})


# Replace the reconstruction decoder with a classification head of the same
# depth that outputs one logit per bucket (len(bins) + 1 classes).
model.decoder = torch.nn.Sequential(
    torch.nn.Linear(hidden_c//4, hidden_c//2),
    torch.nn.ReLU(),
    torch.nn.Dropout(dropout_p),
    torch.nn.Linear(hidden_c//2, hidden_c),
    torch.nn.ReLU(),
    torch.nn.Dropout(dropout_p),
    torch.nn.Linear(hidden_c, len(bins) + 1),
)

### freeze encoder
# Keep the pre-trained encoder weights fixed and train only the new head.
for module, trainable in ((model.encoder, False), (model.decoder, True)):
    for param in module.parameters():
        param.requires_grad = trainable
model = model.to(device)
# Fresh loss, optimizer and scheduler for the classification stage.
# NOTE(review): the learning rate is hard-coded here; LEARNING_RATE is not used for this stage.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=10,
)


In [None]:
# Turn the bin edges into a tensor on the compute device so they can be
# passed to torch.bucketize as boundaries.
bins = torch.tensor(bins).to(device)


In [None]:
def train():
    """Run one full-batch training step of the classifier head.

    Only nodes that are in ``data.train_mask`` and have ``y > 0`` contribute
    to the loss. Their targets are the bucket indices of ``y`` with respect to
    ``bins``.

    Returns:
        The cross-entropy loss of this step as a 0-dim tensor, detached from
        the autograd graph.
    """
    model.train()
    data.x = data.x.to(device)
    data.edge_index = data.edge_index.to(device)
    data.test_mask = data.test_mask.to(device)
    data.y = data.y.to(device)
    optimizer.zero_grad()  # Clear gradients.
    # One label per node, always 1-D. reshape(-1) (rather than squeeze) keeps
    # the target 1-D even when the mask selects exactly one node.
    labels = data.y.reshape(-1)
    mask = data.train_mask.reshape(-1).to(device) & (labels > 0)
    out = model(data.x)  # Perform a single forward pass.
    target = torch.bucketize(labels[mask], bins)
    loss = criterion(out[mask], target.long())  # Target must be 1-D and long
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    # Detach so the caller (scheduler / logging) does not keep the graph alive.
    return loss.detach()

def test():
    """Evaluate the classifier head on the validation nodes.

    Only nodes that are in ``data.val_mask`` and have ``y > 0`` are scored.
    Their targets are the bucket indices of ``y`` with respect to ``bins``.

    Returns:
        A tuple ``(accuracy, out, loss)``: the fraction of scored nodes whose
        arg-max prediction matches the target (NaN when no node is scored),
        the model output for all nodes, and the cross-entropy loss on the
        scored nodes.
    """
    model.eval()
    data.x = data.x.to(device)
    data.edge_index = data.edge_index.to(device)
    data.val_mask = data.val_mask.to(device)
    data.y = data.y.to(device)
    # One label per node, always 1-D. reshape(-1) (rather than squeeze) keeps
    # the target 1-D even when the mask selects exactly one node.
    labels = data.y.reshape(-1)
    mask = data.val_mask.reshape(-1) & (labels > 0)
    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        out = model(data.x)
        target = torch.bucketize(labels[mask], bins)
        loss = criterion(out[mask], target.long())  # Target must be 1-D and long
        correct_preds = out[mask].argmax(dim=1)
        correct = (correct_preds == target).sum().item()
    n_scored = int(mask.sum().item())
    # Avoid a ZeroDivisionError when the validation split has no positive node.
    accuracy = correct / n_scored if n_scored else float("nan")
    return accuracy, out, loss


In [None]:
# Fine-tune the classification head: one full-batch step per epoch, followed
# by a validation pass. The plateau scheduler is driven by the training loss.
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_accuracy, out, val_loss = test()
    scheduler.step(train_loss)
    epoch_summary = f'Epoch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}'
    print(epoch_summary, flush = True)
