PACKAGE

In [1]:
from tqdm import tqdm, trange
import argparse
from PIL import Image

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
from torch_geometric.nn import Sequential, GATConv
from torch_geometric.data import Data

In [4]:
import numpy as np
import pandas as pd

In [5]:
import networkx as nx

CONFIG

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [7]:
tqdm.pandas()

In [8]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type = str, default = 'Cora')
parser.add_argument('--hidden_channels', type = int, default = 8)
parser.add_argument('--heads', type = int, default = 8)
parser.add_argument('--lr', type = float, default = 0.005)
parser.add_argument('--epochs', type = int, default = 200)
parser.add_argument('--wandb', action = 'store_true', help = 'Track experiment')
parser.add_argument('--superpixel', type = str, default = 'slic')
parser.add_argument('--image', type = str, default = 'rgb')
args = parser.parse_known_args()[0]

DATASET

In [10]:
selected_df = pd.read_csv('./select.csv', delimiter = ',')
selected_df['id'] = range(len(selected_df))
selected_df = selected_df.loc[:5]                 # HACK: Limit the number for classification

In [11]:
class2id = {directory: id for directory, id in zip(selected_df['directory'], selected_df['id'])}

In [18]:
dfs = []

for filename in tqdm(os.listdir('./filtered/train')):
    df = pd.read_pickle(f'./filtered/train/{filename}')
    df = df[df['label'].apply(lambda x : x in class2id.keys())]
    df['label'] = df['label'].apply(lambda x : class2id[x])
    df = df[[f'{args.superpixel}_{args.image}_global_graph', f'{args.superpixel}_{args.image}', f'{args.image}', 'label']]
    df.columns = ['graph', 'superpixel', 'image', 'label']
    dfs.append(df)

train_df = pd.concat(dfs).reset_index(drop = True)

100%|██████████| 100/100 [09:34<00:00,  5.74s/it]


In [19]:
dfs = []

for filename in tqdm(os.listdir('./filtered/val')):
    df = pd.read_pickle(f'./filtered/val/{filename}')
    df = df[df['label'].apply(lambda x : x in class2id.keys())]
    df['label'] = df['label'].apply(lambda x : class2id[x])
    df = df[[f'{args.superpixel}_{args.image}_global_graph', f'{args.superpixel}_{args.image}', f'{args.image}', 'label']]
    df.columns = ['graph', 'superpixel', 'image', 'label']
    dfs.append(df)

valid_df = pd.concat(dfs).reset_index(drop = True)

100%|██████████| 30/30 [00:20<00:00,  1.47it/s]


In [20]:
num_classes = train_df['label'].max()
num_classes

5

FEATURE ENGINEERING

In [None]:
def get_supix_statistics(data):
    graph = data['graph']
    superpixel = data['superpixel']
    image = data['image']
    label = data['label']
    num_superpixel = data['superpixel'].max()

    means = []
    stds = []
    centroids = []

    for supix in range(superpixel.max()):
        mask = superpixel != supix      # Mask out the pixels that are not equal to given superpixel label.
        trinary_mask = np.stack([mask, mask, mask], axis = 2)

        masked = np.ma.masked_array(image, trinary_mask)
        mean = np.ma.mean(masked, axis = (0, 1))
        std = np.ma.std(masked, axis = (0, 1))
        centroid = np.array([np.mean(subset) for subset in np.nonzero(np.logical_not(mask))])

        means.append(mean.data)
        stds.append(std.data)
        centroids.append(centroid)

    return means, stds, centroids

In [None]:
ret = df.progress_apply(get_supix_statistics, axis = 1)
df['means'] = [r[0] for r in ret]
df['stds'] = [r[1] for r in ret]
df['centroids'] = [r[2] for r in ret]

CONSTRUCT NODE FEATURES

In [None]:
def construct_node_features(data):
    graph = data['graph']
    means = data['means']
    stds = data['stds']
    centroids = data['centroids']
    weights = dict()

    for index, (mean, std, centroid) in enumerate(zip(means, stds, centroids)):
        weight = np.concatenate([mean, std, centroid])
        weights[index] = weight
    
    return weights

In [None]:
def construct_new_graph(data):
    graph = data['graph']
    attribute = data['attributes']
    nx.set_node_attributes(graph, attribute, name = 'features')
    return graph

In [None]:
df['attributes'] = df.progress_apply(construct_node_features, axis = 1)

In [None]:
df['graph'] = df.progress_apply(construct_new_graph, axis = 1)

MODEL

In [None]:
class SPINCS(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super(SPINCS, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout = 0.5)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads, dropout = 0.5)

    def forward(self, x, edge_index):
        x = F.dropout(x, p = 0.5, training = self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p = 0.5, training = self.training)
        x = self.conv2(x, edge_index)
        return x

In [None]:
model = SPINCS(8, args.hidden_channels, num_classes, args.heads).to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = 0.005, weight_decay = 1e-5)

TRAIN

In [21]:
from torch_geometric.datasets import Planetoid

In [22]:
dataset = Planetoid(root='./cora', name='Cora')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [23]:
dataset

Cora()

In [None]:
dfs[32]