In [None]:
import torch
from torch_geometric.loader import DataLoader  # Use PyTorch Geometric's DataLoader
from torchvision import datasets, transforms    
import os
from _04_mnist_digits.graph_dataset import GraphDataset

# Transform attached to both dataset objects: converts images to tensors.
transform = transforms.ToTensor()

# Download (if not already present under the root directory) and load the
# MNIST train and test splits. `download=True` makes this cell work on a
# fresh machine; files that already exist are reused.
# NOTE(review): the root is an absolute path ('/data'); confirm it is
# writable in every environment where this notebook is expected to run.
train_dataset = datasets.MNIST(root='/data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='/data', train=False, transform=transform, download=True)

In [2]:

# Pull the image and label tensors stored on each split.
# NOTE(review): `.data` looks like the stored tensor read directly from the
# dataset object, so the `transform` given at construction is presumably not
# applied to these images -- confirm downstream code expects that.
images_train, labels_train = train_dataset.data, train_dataset.targets
images_test, labels_test = test_dataset.data, test_dataset.targets

# Sanity check: report shapes (images first, then labels).
for split_tensor in (images_train, images_test, labels_train, labels_test):
    print(split_tensor.shape)

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])
torch.Size([60000])
torch.Size([10000])


In [3]:
# Pool the train and test splits into single tensors (train rows first),
# concatenating along the leading (sample) dimension.
images = torch.cat([images_train, images_test])
labels = torch.cat([labels_train, labels_test])

# Show the combined shapes, one per line.
print(images.shape, labels.shape, sep="\n")

torch.Size([70000, 28, 28])
torch.Size([70000])


In [None]:
# Cache file for the pre-built graph dataset. The path is relative, so it is
# resolved against the kernel's current working directory.
saved_dataset_path = '04_mnist_digits/data/graph_dataset.pt'

if os.path.exists(saved_dataset_path):
    # Reuse the cached dataset instead of rebuilding it.
    print(f"Loading existing dataset from {saved_dataset_path}")
    # SECURITY: weights_only=False performs full unpickling, which can execute
    # arbitrary code. Only load cache files that this notebook produced.
    dataset = torch.load(saved_dataset_path, weights_only=False)
    print(f"Loaded dataset with {len(dataset)} graphs")
else:
    print(f"Creating new dataset with {len(images)} images...")
    # Create the cache directory up front so a missing or unwritable
    # directory is reported before the expensive graph construction,
    # not after it when torch.save is finally called.
    os.makedirs(os.path.dirname(saved_dataset_path), exist_ok=True)
    dataset = GraphDataset(images, labels, use_weighted_edges=True)
    print(f"Saving dataset to {saved_dataset_path}")
    torch.save(dataset, saved_dataset_path)
    print("Dataset saved!")

Creating new dataset with 70000 images...
GraphDataset.__init__ called with 70000 images


100%|██████████| 70000/70000 [11:36<00:00, 100.49it/s]


GraphDataset initialization complete!
Saving dataset to 04_mnist_digits/data/graph_dataset.pt
Dataset saved!


In [5]:
# Batch the graphs ten at a time, keeping the stored order.
loader = DataLoader(dataset, shuffle=False, batch_size=10)

# Inspect the first ten individual graphs.
for i in range(10):
    graph = dataset[i]
    print(graph)

# Inspect only the first collated batch; loop over `loader` instead to walk
# through the entire dataset.
batch = next(iter(loader))
print(batch)

Data(x=[34], edge_index=[2, 111], y=5, edge_weight=[111])
Data(x=[34], edge_index=[2, 109], y=0, edge_weight=[109])
Data(x=[28], edge_index=[2, 86], y=4, edge_weight=[86])
Data(x=[19], edge_index=[2, 51], y=1, edge_weight=[51])
Data(x=[24], edge_index=[2, 78], y=9, edge_weight=[78])
Data(x=[27], edge_index=[2, 92], y=2, edge_weight=[92])
Data(x=[20], edge_index=[2, 58], y=1, edge_weight=[58])
Data(x=[30], edge_index=[2, 96], y=3, edge_weight=[96])
Data(x=[15], edge_index=[2, 41], y=1, edge_weight=[41])
Data(x=[20], edge_index=[2, 59], y=4, edge_weight=[59])
DataBatch(x=[251], edge_index=[2, 781], y=[10], edge_weight=[781], batch=[251], ptr=[11])
