In [1]:
import numpy as np
import os.path as osp
import pickle
import torch
import torch.utils
import torch.utils.data
import torch.nn.functional as F
from scipy.spatial.distance import cdist
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.datasets import MNISTSuperpixels
import random
from collections import defaultdict

In [2]:
def compute_adjacency_matrix_images(coord, sigma=0.1):
    coord = coord.reshape(-1, 2)
    dist = cdist(coord, coord)
    A = np.exp(- dist / (sigma * np.pi) ** 2)
    A[np.diag_indices_from(A)] = 0
    return A


In [3]:
dataset_raw = torch.load('./mnist/raw/MNISTSuperpixels.pt')

In [4]:
mean_px = dataset_raw[0][0]['x']
label_ex = (mean_px > 0).long()
print(label_ex)

tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]])


In [5]:
from tqdm import tqdm
# we filter the edges with a distance less than 0.1
dataset_filtered = []
for index, sample in enumerate(tqdm(dataset_raw[0])):
    mean_px, coord = sample['x'], sample['pos']
    coord = coord / 28
    A = compute_adjacency_matrix_images(coord)
    N_nodes = A.shape[0]
    label_ex = (mean_px > 0).long()
    A = torch.FloatTensor((A > 0.1) * A)
    edge_index, edge_attr = dense_to_sparse(A)
    x = mean_px.reshape(N_nodes, -1)
    coord = coord.reshape(N_nodes, 2)
    x = np.concatenate((x, coord), axis=1)
    row, col = edge_index
    dataset_filtered.append(
        Data(
            x=torch.tensor(x), 
            y=torch.LongTensor(sample['y']), 
            edge_index=edge_index,
            edge_attr=edge_attr,
            label_ex = label_ex
        )
    )
for index, sample in enumerate(tqdm(dataset_raw[1])):
    mean_px, coord = sample['x'], sample['pos']
    coord = coord / 28
    A = compute_adjacency_matrix_images(coord)
    N_nodes = A.shape[0]
    label_ex = (mean_px > 0).long()
    A = torch.FloatTensor((A > 0.1) * A)
    edge_index, edge_attr = dense_to_sparse(A)
    x = mean_px.reshape(N_nodes, -1)
    coord = coord.reshape(N_nodes, 2)
    x = np.concatenate((x, coord), axis=1)
    row, col = edge_index
    dataset_filtered.append(
        Data(
            x=torch.tensor(x), 
            y=torch.LongTensor(sample['y']), 
            edge_index=edge_index,
            edge_attr=edge_attr,
            label_ex = label_ex
        )
    )


100%|██████████| 60000/60000 [00:32<00:00, 1843.86it/s]
100%|██████████| 10000/10000 [00:04<00:00, 2036.66it/s]


In [22]:
# We define 0, 1,  4, 6,  9 as Meta_Train_classes, 2, 5, 8 as Meta_Test classes, 3,7 as Meta_Validation classes

In [6]:
label_to_data_dict = defaultdict(list)

In [7]:
for data in dataset_filtered:
    label = data['y'].item()
    label_to_data_dict[label].append(data)

In [8]:
train_label_to_data_dict, test_label_to_data_dict, val_label_to_data_dict = defaultdict(list), defaultdict(list), defaultdict(list)

In [9]:
for i in [0, 1, 4, 6, 9]:
    train_label_to_data_dict[i] = label_to_data_dict[i]
    print(len(label_to_data_dict[i]))

for i in [3,7]:
    val_label_to_data_dict[i] = label_to_data_dict[i]
    print(len(label_to_data_dict[i]))
    
for i in [2, 5, 8]:
    test_label_to_data_dict[i] = label_to_data_dict[i]
    print(len(label_to_data_dict[i]))

6903
7877
6824
6876
6958
7141
7293
6990
6313
6825


In [10]:
val_label_to_data_dict

defaultdict(list,
            {3: [Data(x=[75, 3], edge_index=[2, 748], edge_attr=[748], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_index=[2, 708], edge_attr=[708], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_index=[2, 740], edge_attr=[740], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_index=[2, 698], edge_attr=[698], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_index=[2, 804], edge_attr=[804], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_index=[2, 822], edge_attr=[822], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_index=[2, 726], edge_attr=[726], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_index=[2, 770], edge_attr=[770], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_index=[2, 774], edge_attr=[774], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_index=[2, 762], edge_attr=[762], y=[1], label_ex=[75, 1]),
              Data(x=[75, 3], edge_inde

In [11]:
for i in train_label_to_data_dict:
    print(len(train_label_to_data_dict[i]))
for i in val_label_to_data_dict:
    print(len(val_label_to_data_dict[i]))
for i in test_label_to_data_dict:
    print(len(test_label_to_data_dict[i]))

6903
7877
6824
6876
6958
7141
7293
6990
6313
6825


In [12]:
train_data, val_data, test_data = [], [], []
for i in train_label_to_data_dict:
    train_data.extend(train_label_to_data_dict[i])
for i in val_label_to_data_dict:
    val_data.extend(val_label_to_data_dict[i])
for i in test_label_to_data_dict:
    test_data.extend(test_label_to_data_dict[i])
torch.save(train_data, './mnist/train.pt')
torch.save(val_data, './mnist/val.pt')
torch.save(test_data, './mnist/test.pt')