### Homework 2: Graph Neural Networks

- Author: Ludek Cizinsky (`ludek.cizinsky@epfl.ch`)

In [44]:
# Hugging face util to download dataset
from datasets import load_dataset

# Scikit-learn for train-test split
from sklearn.model_selection import train_test_split

# PyTorch
import torch
# - Dataloader
from torch.utils.data import Dataset, DataLoader

# Custom scripts
from scripts.dataset import GraphDataset


In [45]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Load the dataset

Some important notes:

- The dataset is a collection of chemical compounds represented as graphs (details are specified below): each sample is one graph with a corresponding ground-truth label indicating whether the compound is mutagenic or not.
- At a lower level, each node has an associated embedding (a one-hot encoding) indicating its type, and the same goes for the edges; **the dimensions of the node and edge embeddings are different**.

#### Download the dataset from Hugging Face (HF)

In [46]:
# Fetch MUTAG from the Hugging Face Hub and keep only its 'train' split
mutag_splits = load_dataset("graphs-datasets/MUTAG")
dataset_hf = mutag_splits['train']

#### Train, validation and test split

In [47]:
# Parse the dataset into samples (X) and labels (y).
# Every sample keeps its label(s) as an iterable under the 'y' key; they are
# flattened into a single list of labels.
X = list(dataset_hf)
y = [label for sample in X for label in sample['y']]

# Flattening only keeps X and y aligned when each sample carries exactly one
# label, so fail loudly here instead of training on silently shifted labels.
assert len(X) == len(y), (
    f"Expected exactly one label per sample, got {len(y)} labels for {len(X)} samples"
)

# Hold out 20% of all samples as the test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Hold out 20% of the remaining samples for validation
# (overall proportions: ~64% train / ~16% validation / 20% test)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

# The three splits must partition the dataset
assert len(X_train) + len(X_val) + len(X_test) == len(X)

#### Load it using custom dataloader

In [48]:
# Define hyperparameters for the dataloader
# - batch_size: number of samples (graphs) grouped into one mini-batch
# - shuffle: passed as the DataLoader `shuffle` argument below; when True the
#   sample order is reshuffled on every pass over the data
batch_size = 32
shuffle = True

# Custom batching behavior for the dataloader
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    Args:
        batch: sequence of samples, each itself an iterable of fields
            (presumably what GraphDataset.__getitem__ returns; verify there).

    Returns:
        A tuple with one entry per field; entry ``i`` is a tuple holding the
        ``i``-th field of every sample. An empty batch yields ``()``.
    """
    # zip(*batch) groups the i-th element of every sample together
    fields = zip(*batch)
    return tuple(fields)

# Wrap each split in the project's GraphDataset
train_dataset = GraphDataset(X_train, y_train)
val_dataset = GraphDataset(X_val, y_val)
test_dataset = GraphDataset(X_test, y_test)

# Training loader: sample order follows the `shuffle` hyperparameter
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=collate_fn
)

# Evaluation loaders: never shuffled, so validation/test batches are visited
# in the same order on every run.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn
)