In [None]:
import os
import re
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.loader import DataLoader

import numpy as np
from sklearn.preprocessing import LabelEncoder

In [None]:
# Graph transforms used while building the dataset.
import torch_geometric.transforms as T

# Distance transform, constructed with its default arguments. run_parse applies
# it to every parsed graph; presumably it stores node-to-node distances as
# edge attributes -- confirm against the PyG documentation.
dist_transform = T.Distance()

# KNN-graph transform, constructed with its default arguments. It is
# instantiated here but is not applied by run_parse as written.
knn_transform = T.KNNGraph()

In [None]:
def make_labels(file_path, skip=0):
    """Parse a classification file into an ``{instance index: label}`` dict.

    The file is read line by line:

    * A line that starts with a run of non-digit, non-whitespace words is a
      label header. The label is those leading words joined as they appear;
      anything from the first digit-containing word onwards is ignored.
    * Any other line holding exactly one token is an instance index and is
      assigned to the most recent label header.
    * Every other line (blank lines, multi-number lines) is ignored.

    Args:
        file_path: Path of the classification file.
        skip: Number of leading lines to ignore (default 0).

    Returns:
        dict mapping each integer instance index to its label string. If an
        index is listed more than once, its last occurrence in the file wins.

    Raises:
        ValueError: If a single-token, non-header line is not an integer.
    """
    label_pattern = re.compile(r'^([^\d\s]+(?:\s+[^\d\s]+)*)')
    current_label = None
    mapping = {}

    with open(file_path, "r") as f:
        for line_no, line in enumerate(f):
            if line_no < skip:
                continue

            header = label_pattern.search(line)
            if header:
                # Only remember the label; indices are attached as they are
                # read, so a header that reappears later in the file no longer
                # discards the indices collected under it earlier.
                current_label = header.group(1).strip()
                continue

            tokens = line.split()
            # An index seen before any header has no label to attach to; it is
            # skipped instead of failing with ``KeyError: None``.
            if len(tokens) == 1 and current_label is not None:
                mapping[int(tokens[0])] = current_label

    return mapping

In [None]:
def parse_off(file_path):
    """Parse an .off mesh file into a graph ``Data`` object.

    The first line of the file is skipped and the second line is read as the
    header, whose first two fields are the vertex count and the face count.
    Vertex lines follow, then face lines of the form
    ``<corner count> <v0> <v1> ...``.

    Args:
        file_path: Path of the .off file.

    Returns:
        Data with
        * ``x``: vertex coordinates, one row per vertex,
        * ``pos``: an independent copy of the same coordinates,
        * ``edge_index``: a ``[2, num_edges]`` long tensor holding one directed
          edge per consecutive corner pair of every face (each face is closed
          back to its first corner). Reverse edges and de-duplication are not
          added here.
    """
    with open(file_path, 'r') as f:
        f.readline()  # skip 1st line
        # split() without an argument tolerates tabs and repeated spaces;
        # split(' ') yields empty strings there and float('') / int('') fail.
        header = f.readline().split()
        num_vertices = int(header[0])
        num_faces = int(header[1])

        # Read the vertices
        coords = [
            [float(value) for value in f.readline().split()]
            for _ in range(num_vertices)
        ]
        vertices = torch.tensor(coords, dtype=torch.float)

        # Read the faces and build the edges
        edges = []
        for _ in range(num_faces):
            tokens = f.readline().split()
            if not tokens:
                # Blank or missing face line contributes no edges.
                continue
            # Only the first <corner count> fields are vertex ids; anything
            # after them (e.g. a colour spec) must not be read as a vertex.
            corner_count = int(tokens[0])
            face = [int(tok) for tok in tokens[1:1 + corner_count]]
            for j in range(len(face)):
                edges.append((face[j], face[(j + 1) % len(face)]))

        # reshape(-1, 2) keeps the tensor 2-D when there are no edges, so the
        # transpose below yields an empty [2, 0] index instead of raising.
        # contiguous() gives edge_index a standard memory layout.
        edge_index = (
            torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
        )

        # clone() gives pos its own storage without the copy-construct warning
        # that torch.tensor(<tensor>) emits.
        pos = vertices.clone()

        return Data(x=vertices, edge_index=edge_index, pos=pos)

def run_parse(root_fp, cla_fp):
    """Build the list of labelled graphs for every classified .off file.

    Args:
        root_fp: Folder that directly contains the .off files. Each file name
            must start with its integer instance index followed by a dot
            (e.g. ``12.off``).
        cla_fp: Classification file, parsed with ``make_labels``.

    Returns:
        list of graph objects, in sorted file-name order. Each graph comes
        from ``parse_off``, has ``y`` set to the integer-encoded class label,
        and has been passed through the module-level ``dist_transform``.
        Files whose index is absent from the classification are skipped.

    Raises:
        ValueError: If an .off file name does not start with an integer.
    """
    # Mapping of instance index -> label string, parsed once.
    out_dict = make_labels(cla_fp)
    ulab = np.unique(list(out_dict.values()))

    # LabelEncoder expects a 1-D array of labels; a column vector triggers a
    # DataConversionWarning.
    le = LabelEncoder()
    le.fit(ulab)

    # Parse all files and build a list of graph data objects. os.listdir
    # returns entries in arbitrary order, so sort for a reproducible result.
    graphs = []
    for filename in sorted(os.listdir(root_fp)):
        if not filename.endswith('.off'):  # check file type
            continue
        file_index = int(filename.split(".")[0])
        if file_index not in out_dict:  # index not in the classification
            continue
        graph = parse_off(os.path.join(root_fp, filename))
        file_label = le.transform([out_dict[file_index]])
        graph.y = torch.tensor(file_label[0])
        # add edge_attr
        graph = dist_transform(graph)
        graphs.append(graph)
    return graphs

In [None]:
#Call dataloader directly
#loader = DataLoader(graphs, batch_size= 16, drop_last=True)
#torch.save(loader,'psb_loader')#this loader does not have attributes like num_classes

In [None]:
from torch_geometric.data import InMemoryDataset

class MyDataset(InMemoryDataset):
    """Dataset wrapper around an already-built list of graph objects.

    The list is kept as-is in ``self.data_list``; length and indexing are
    delegated straight to it.
    """

    def __init__(self, data_list):
        """Store ``data_list`` after running the base-class initialiser."""
        super().__init__()
        self.data_list = data_list

    @property
    def num_classes(self):
        """Number of distinct ``y`` values found across the stored graphs."""
        targets = torch.tensor([sample.y for sample in self.data_list])
        distinct_targets = torch.unique(targets)
        return len(distinct_targets)

    def __len__(self):
        """Number of stored graphs."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return whatever the underlying list yields for ``index``."""
        return self.data_list[index]

In [None]:
# Running module: parse the meshes, wrap them in a dataset and save it.
# The '...' arguments are placeholders: supply the folder of .off files, the
# classification file and the output path before running this cell.
graphs = run_parse(root_fp='...',cla_fp='...')
# NOTE(review): a bare expression that is not the last line of a cell is not
# displayed, so this line shows nothing as written.
len(graphs)
psb_set = MyDataset(graphs)
# Saves the whole dataset object; presumably MyDataset must be defined in the
# session that loads it back -- confirm against the torch.save/torch.load docs.
torch.save(psb_set,'...')
# To reload a previously saved dataset instead of re-parsing:
#psb_set = torch.load('psb.pt')