In [None]:
import glob
import re

import dgl
import dgl.nn.pytorch as dglnn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
exc_feat_path = './data/train/feat/exc/'
dum_feat_path = './data/train/feat/dum/'
cos_feat_path = './data/train/feat/cos/'

exc_edge_path = './data/train/edge/exc/'
dum_edge_path = './data/train/edge/dum/'
cos_edge_path = './data/train/edge/cos/'


def _natural_key(path):
    """Sort key that compares digit runs numerically ('2.txt' before '10.txt')."""
    # re.split with a capture group alternates text (even index) and digit runs
    # (odd index), so two keys always compare str-with-str and int-with-int.
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'(\d+)', path))]


def _list_txt(directory):
    """Return the '*.txt' files in ``directory`` in a deterministic, natural order.

    glob.glob returns files in filesystem-dependent order, but the edge lists,
    feature lists and label rows below are paired purely by list index, so all
    of them must be ordered the same way.
    """
    return sorted(glob.glob(directory + '*.txt'), key=_natural_key)


exc_edge_list = _list_txt(exc_edge_path)
dum_edge_list = _list_txt(dum_edge_path)
cos_edge_list = _list_txt(cos_edge_path)

exc_feat_list = _list_txt(exc_feat_path)
dum_feat_list = _list_txt(dum_feat_path)
cos_feat_list = _list_txt(cos_feat_path)

# NOTE(review): row i of labels.txt is assumed to belong to the i-th file in
# natural filename order — confirm against how the data was exported.
labels = pd.read_csv('./data/labels.txt', header=None, sep=' ').values

def create_graph(n):
    """Build the n-th heterogeneous graph and its label tensor.

    For each node type ('exc', 'dum', 'cos') this reads the n-th file of the
    module-level edge list (comma-separated; column 0 = source id, column 1 =
    destination id) and feature list (tab-separated, one row per node). It then
    adds two relations per type: 'in<type>' (src -> dst) and 'out<type>'
    (dst -> src). Features are stored as float32 in ``ndata['feat']``, and the
    graph is placed on ``device``.

    Args:
        n: index into the per-type file lists and into the global ``labels`` array.

    Returns:
        ``(hetero_graph, label)`` where ``label`` is ``labels[n]`` as an int64
        tensor on ``device``.
    """
    graph_data = {}
    num_nodes = {}
    features = {}
    for ntype, edge_files, feat_files in (
        ('exc', exc_edge_list, exc_feat_list),
        ('dum', dum_edge_list, dum_feat_list),
        ('cos', cos_edge_list, cos_feat_list),
    ):
        edges = pd.read_csv(edge_files[n], sep=',', header=None).to_numpy()
        src, dst = edges[:, 0], edges[:, 1]
        feat = pd.read_csv(feat_files[n], sep='\t', header=None).to_numpy()

        graph_data[(ntype, 'in' + ntype, ntype)] = (src, dst)
        graph_data[(ntype, 'out' + ntype, ntype)] = (dst, src)
        # Take the node count from the feature table. Otherwise DGL infers it
        # from the largest edge id, and trailing isolated nodes would make the
        # feature assignment below fail with a shape mismatch.
        num_nodes[ntype] = feat.shape[0]
        features[ntype] = feat

    hetero_graph = dgl.heterograph(graph_data, num_nodes_dict=num_nodes).to(device)
    for ntype, feat in features.items():
        hetero_graph.nodes[ntype].data['feat'] = torch.tensor(
            feat, dtype=torch.float32, device=device)

    label = torch.as_tensor(labels[n], dtype=torch.long, device=device)
    return hetero_graph, label

class RGCN(nn.Module):
    """Two-layer relational GCN built from one ``GraphConv`` per relation.

    Each layer is a ``HeteroGraphConv`` that runs a separate ``GraphConv`` on
    every relation in ``rel_names`` and sums the per-relation results for each
    destination node type. ReLU is applied between the layers; the second
    layer's output is returned without an activation.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        def make_layer(n_in, n_out):
            # One GraphConv per relation, combined by summation.
            return dglnn.HeteroGraphConv(
                {rel: dglnn.GraphConv(n_in, n_out) for rel in rel_names},
                aggregate='sum')

        self.conv1 = make_layer(in_feats, hid_feats)
        self.conv2 = make_layer(hid_feats, out_feats)

    def forward(self, graph, inputs):
        """Run both layers; ``inputs`` and the result are dicts keyed by node type."""
        first = self.conv1(graph, inputs)
        activated = {ntype: F.relu(feat) for ntype, feat in first.items()}
        return self.conv2(graph, activated)

class HeteroClassifier(nn.Module):
    """Graph-level classifier for heterographs.

    Two stacked ``RGCN`` blocks produce per-node embeddings for every node
    type. Each node type is mean-pooled per graph, the pooled vectors are
    summed across node types, and a linear layer maps the result to
    ``n_classes`` logits.
    """

    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()

        self.rgcn1 = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.rgcn2 = RGCN(hidden_dim, hidden_dim, hidden_dim, rel_names)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        """Return logits of shape (num_graphs_in_batch, n_classes) for graph ``g``."""
        h = g.ndata['feat']
        h = self.rgcn1(g, h)
        # RGCN returns its last convolution without an activation. Without this
        # ReLU, rgcn1's final conv and rgcn2's first conv would compose as two
        # linear graph convolutions back to back.
        h = {ntype: F.relu(feat) for ntype, feat in h.items()}
        h = self.rgcn2(g, h)
        with g.local_scope():
            g.ndata['h'] = h
            # Readout: mean over nodes of each type, summed over node types.
            hg = sum(dgl.mean_nodes(g, 'h', ntype=ntype) for ntype in g.ntypes)
            return self.classify(hg)

def collate(samples):
    """Combine ``(graph, label)`` pairs into one batched graph and one label tensor.

    Args:
        samples: sequence of ``(DGLGraph, label_tensor)`` pairs, as produced by
            ``create_graph``.

    Returns:
        ``(batched_graph, batched_labels)``. ``batched_labels`` stacks the
        per-sample label tensors along a new leading dimension, so a per-sample
        shape ``(k,)`` becomes ``(B, k)``.
    """
    # Named sample_labels so the global ``labels`` table is never shadowed.
    graphs, sample_labels = zip(*samples)
    batched_graph = dgl.batch(graphs)
    # torch.stack is the supported way to join same-shaped tensors. It keeps
    # their device and dtype, unlike rebuilding a new tensor with torch.tensor().
    batched_labels = torch.stack(sample_labels)
    return batched_graph, batched_labels


# create_graph pairs files purely by index, so every per-type list must hold
# the same number of files, and the label table needs at least that many rows.
_file_counts = {
    'exc_edge': len(exc_edge_list),
    'dum_edge': len(dum_edge_list),
    'cos_edge': len(cos_edge_list),
    'exc_feat': len(exc_feat_list),
    'dum_feat': len(dum_feat_list),
    'cos_feat': len(cos_feat_list),
}
assert len(set(_file_counts.values())) == 1, f'mismatched file counts: {_file_counts}'
assert len(exc_feat_list) > 0, 'no graph files found under ./data/train/'
assert len(labels) >= len(exc_feat_list), (
    f'{len(labels)} label rows for {len(exc_feat_list)} graphs')

dataset = [create_graph(n) for n in range(len(exc_feat_list))]

In [None]:

# Hyper-parameters and seed. The split is seeded through random_state, and
# torch.manual_seed also fixes weight initialisation and DataLoader shuffling.
SEED = 137
BATCH_SIZE = 256
HIDDEN_DIM = 256
N_CLASSES = 4
N_EPOCHS = 1000

torch.manual_seed(SEED)

train_dataset, val_dataset = train_test_split(dataset, test_size=0.1, random_state=SEED)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate,
    drop_last=False,
    shuffle=True)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate,
    drop_last=False)

sample_graph = dataset[0][0]
etypes = sample_graph.etypes
# All node types feed the same GraphConv input width, so they must agree;
# derive that width from the data instead of hard-coding it.
feat_dims = {ntype: sample_graph.nodes[ntype].data['feat'].shape[1]
             for ntype in sample_graph.ntypes}
assert len(set(feat_dims.values())) == 1, f'node types need one feature width, got {feat_dims}'
in_dim = next(iter(feat_dims.values()))

model = HeteroClassifier(in_dim, HIDDEN_DIM, N_CLASSES, etypes).to(device)
opt = torch.optim.Adam(model.parameters())

loss_values = []  # mean training loss per epoch
acc_values = []   # validation accuracy per epoch

for epoch in range(N_EPOCHS):
    # Training. Loop variables avoid the name `labels`, which would overwrite
    # the global label table that create_graph reads.
    model.train()
    train_loss_sum = 0.0
    train_seen = 0
    for batched_graph, batch_labels in train_dataloader:
        batched_graph = batched_graph.to(device)
        # reshape(-1) yields a 1-D target vector for both (B,) and (B, 1) labels.
        targets = batch_labels.to(device).reshape(-1)
        logits = model(batched_graph)
        loss = F.cross_entropy(logits, targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Weight by batch size so a short final batch doesn't skew the mean.
        train_loss_sum += loss.item() * targets.numel()
        train_seen += targets.numel()
    train_loss = train_loss_sum / train_seen
    loss_values.append(train_loss)
    print(f'Epoch {epoch}, Training Loss: {train_loss}')

    # Validation: accumulate per-sample totals, then divide once.
    model.eval()
    val_loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batched_graph, batch_labels in val_dataloader:
            batched_graph = batched_graph.to(device)
            targets = batch_labels.to(device).reshape(-1)
            logits = model(batched_graph)
            val_loss_sum += F.cross_entropy(logits, targets, reduction='sum').item()
            predicted = logits.argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.numel()

    val_loss = val_loss_sum / n_seen
    val_acc = n_correct / n_seen
    acc_values.append(val_acc)
    print(f'Epoch {epoch}, Validation Loss: {val_loss}, Validation Accuracy: {val_acc}')

