### Graph Auto-Encoders

In this exercise, we will implement a GCN-based autoencoder that learns low-dimensional node embeddings, i.e. performs dimensionality reduction.

Autoencoders such as the one we will implement today can be thought of as non-linear generalizations of PCA.

We begin by preparing the data:

1. Load the CHILI-Challenge tensor data and labels.
2. Create the Torch dataset and dataloaders.
3. Complete the autoencoder forward call.
4. Visualize the low dimensional embeddings

In [None]:
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import download_url, extract_zip
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader
import pdb
import matplotlib.pyplot as plt

In [None]:
# Local folder that holds the dataset; it is created the first time this cell
# runs, and only reported on later runs.
data_loc = './Data'
data_dir = Path(data_loc)
if data_dir.exists():
    print(data_loc, 'exists!')
else:
    data_dir.mkdir()

In [None]:
# Location of the dataset archive and the name of the dataset folder.
url = 'https://sid.erda.dk/share_redirect/h7plnJoaYR/CHILI-Challenge.zip'
dataset_name = 'CHILI-Challenge'

# `data_loc` starts out as the download root and is advanced to the dataset
# folder exactly once. The guard keeps a re-run of this cell from looking for
# the archive inside the dataset folder and downloading it a second time.
if not data_loc.endswith('/' + dataset_name):
    if not Path(data_loc + '/' + dataset_name + '.zip').exists():
        print('Data not locally found. Downloading...')
        path = download_url(url, data_loc)
        print('Extracting data...')
        extract_zip(path, data_loc)
        print('Done!')
    else:
        print('Data found at ' + data_loc)
    # Point at the dataset folder in BOTH cases so later cells see the same
    # state after a fresh download as after a cached run.
    # NOTE(review): assumes the archive unpacks into `<data_loc>/<dataset_name>`,
    # which is the layout the later `torch.load` call relies on.
    data_loc += '/' + dataset_name

print(os.listdir(data_loc))

In [None]:
class GraphAutoencoder(nn.Module):
    """Autoencoder over node features built from four ``GCNConv`` layers.

    The encoder maps ``input_dim -> hidden_dim -> latent_dim`` and the decoder
    mirrors it with ``latent_dim -> hidden_dim -> input_dim``. In each half a
    ReLU follows the first layer only; the second layer's output is returned
    without an activation.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        """Create the two encoder and two decoder graph-convolution layers."""
        super().__init__()

        # Encoder: input_dim -> hidden_dim -> latent_dim
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, latent_dim)

        # Decoder: latent_dim -> hidden_dim -> input_dim
        self.deconv1 = GCNConv(latent_dim, hidden_dim)
        self.deconv2 = GCNConv(hidden_dim, input_dim)

    def encode(self, x, edge_index):
        """Return the latent representation of the node features ``x``."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        return self.conv2(hidden, edge_index)

    def decode(self, z, edge_index):
        """Return node features reconstructed from the latent codes ``z``."""
        hidden = self.deconv1(z, edge_index)
        hidden = F.relu(hidden)
        return self.deconv2(hidden, edge_index)

    def forward(self, x, edge_index):
        """Encode then decode ``x``.

        Returns:
            A ``(x_hat, z)`` tuple: the reconstructed node features and the
            latent representation they were decoded from.
        """
        latent = self.encode(x, edge_index)
        reconstruction = self.decode(latent, edge_index)
        return reconstruction, latent

In [None]:
# Load the training split from disk.
# PyTorch >= 2.6 defaults to `weights_only=True`, which refuses to unpickle
# arbitrary Python objects. The later cells index `train_set` and read
# attributes such as `.x` and `.y[...]` from its elements, so the file
# presumably holds pickled graph objects rather than bare tensors; request
# full unpickling explicitly.
# NOTE(security): `weights_only=False` executes pickle code from the file.
# Only use it with data from a source you trust.
train_set = torch.load('Data/CHILI-Challenge/train.pt', weights_only=False)

In [None]:
# Look at the first element of the training set (in a notebook the value of
# this expression is shown as the cell output).
train_set[0]

In [None]:
# Example usage
# `input_dim` has to match the width of the concatenated node features built
# in the loop below (x, pos_abs and pos_frac joined along dim=1), because the
# reconstruction is compared element-wise against that same tensor.
input_dim = 10  # Input node feature dimension
hidden_dim = 32  # Hidden layer dimension
latent_dim = 2  # Latent space dimension
B = 16  # Batch size (graphs per mini-batch)
num_epochs = 10
train_loader = DataLoader(train_set, batch_size=B)

model = GraphAutoencoder(input_dim, hidden_dim, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the autoencoder to reconstruct its own input node features.
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for data in train_loader:
        # Node features: per-node attributes plus absolute and fractional positions.
        x, edge_index = torch.cat((data.x, data.pos_abs, data.pos_frac), dim=1), data.edge_index

        optimizer.zero_grad()
        x_hat, z = model(x, edge_index)
        loss = F.mse_loss(x_hat, x)  # Mean Squared Error loss
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean over all mini-batches instead of only the last batch,
    # which is noisy and undefined when the loader yields nothing.
    mean_loss = epoch_loss / num_batches if num_batches else float('nan')
    print(f'Epoch {epoch + 1}, Loss: {mean_loss}')


In [None]:
# Two handcrafted descriptors per graph, used as a 2-D baseline:
#   feature_1 - mean over all entries of the node-feature tensor `x`
#   feature_2 - number of rows in `x`
# Points are coloured by the 'crystal_type_number' entry of each graph's `y`.
feature_1, feature_2, labels = [], [], []
for t in train_set:
    feature_1.append(t.x.mean())
    feature_2.append(t.x.shape[0])
    labels.append(t.y['crystal_type_number'])

plt.figure(figsize=(12, 8))
plt.scatter(feature_1, feature_2, c=labels)
plt.show()

In [None]:
# Learnt low-dimensional features: one embedding per graph.
# With a batch size of 1 every batch is a single graph, so averaging the
# latent codes over dim 0 yields one vector for that graph.
B = 1
train_loader = DataLoader(train_set, batch_size=B)
embeddings = []
with torch.no_grad():  # inference only, no gradients needed
    for data in train_loader:
        node_features = torch.cat((data.x, data.pos_abs, data.pos_frac), dim=1)
        z = model(node_features, data.edge_index)[1]
        embeddings.append(z.mean(0))
embeddings = torch.stack(embeddings)

In [None]:
# Scatter the first two columns of the per-graph embeddings, coloured by label.
latent_x = embeddings[:, 0]
latent_y = embeddings[:, 1]
plt.figure(figsize=(12, 8))
plt.scatter(latent_x, latent_y, c=labels)
plt.show()