In [2]:
# Change the working directory to the parent directory. The `src.*` import and
# the relative "data" path used in later cells presumably rely on this — confirm.
%cd ..

/home/jbananafish/Desktop/Master/Thesis/code/gcnboost


In [30]:
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GCNConv, SAGEConv, Linear, GraphConv, GATConv, to_hetero
import torch_geometric.transforms as T

from src.data.artgraph import ArtGraph

In [4]:
# Fix the random seeds so repeated runs draw the same random numbers:
# first the default generator, then the current CUDA device.
torch.manual_seed(seed=1)
torch.cuda.manual_seed(seed=1)

In [5]:
# Build the ArtGraph dataset from "data" using the 'node2vec' preprocessing,
# with T.ToUndirected() passed as the dataset transform.
# (An earlier variant of this call used "./ekg" as first argument and no transform.)
base_data = ArtGraph(
    "data",
    preprocess='node2vec',
    transform=T.ToUndirected(),
    features=True,
    type='ekg',
)

In [19]:
# Take the first item of the dataset; the `print(data)` cell further down shows
# it is a HeteroData graph.
data = base_data[0]

## Some graph-level information

In [5]:
# Report the number of classes per prediction task, then the input feature size.
# All lines are assembled first and emitted with a single print call.
print(
    "\n".join(
        [
            f"Number of {task} classes: {base_data.num_classes[task]}"
            for task in ("artist", "style", "genre")
        ]
        + [f"Number of input features: {base_data.num_features}"]
    )
)

Number of artist classes: 300
Number of style classes: 83
Number of genre classes: 50
Number of input features: 128


## Some node-level information

In [6]:
# NOTE(review): same assignment as the earlier `data = base_data[0]` cell; the
# execution counts show the cells were run out of order, so one of the two is
# presumably redundant — confirm and drop.
data = base_data[0]

In [13]:
# Show the graph summary: node types with their feature/label/mask tensors,
# followed by the edge types with their edge_index shapes.
print(data)

HeteroData(
  [1martwork[0m={
    x=[61477, 128],
    y_artist=[61477],
    y_style=[61477],
    y_genre=[61477],
    train_mask=[61477],
    val_mask=[61477],
    test_mask=[61477]
  },
  [1martist[0m={ x=[300, 128] },
  [1mgallery[0m={ x=[1090, 128] },
  [1mcity[0m={ x=[665, 128] },
  [1mcountry[0m={ x=[64, 128] },
  [1mstyle[0m={ x=[83, 128] },
  [1mperiod[0m={ x=[53, 128] },
  [1mgenre[0m={ x=[50, 128] },
  [1mserie[0m={ x=[610, 128] },
  [1mauction[0m={ x=[5, 128] },
  [1mtag[0m={ x=[5146, 128] },
  [1mmedia[0m={ x=[160, 128] },
  [1msubject[0m={ x=[2161, 128] },
  [1mtraining_node[0m={ x=[108, 128] },
  [1mfield[0m={ x=[65, 128] },
  [1mmovement[0m={ x=[121, 128] },
  [1mpeople[0m={ x=[48, 128] },
  [1m(artist, influenced_rel, artist)[0m={ edge_index=[2, 62] },
  [1m(artist, subject_rel, subject)[0m={ edge_index=[2, 3648] },
  [1m(artist, training_rel, training_node)[0m={ edge_index=[2, 130] },
  [1m(artist, field_rel, field)[0m={ edge_in

In [31]:
class HomoGNN(torch.nn.Module):
    """Stack of graph-convolution layers followed by a convolutional output layer.

    Every hidden layer applies ``operator`` followed by a ReLU. With
    ``skip=True`` a linear projection of the layer input is added to the
    activated output. Dropout follows each hidden layer, and a final
    ``operator`` layer maps to ``out_channels`` before a log-softmax over
    dimension 1.

    Args:
        operator: Layer class, instantiated as ``operator(in_channels, out_channels)``
            and called as ``layer(x, edge_index)``.
        input_channels: Currently unused: hidden layers are built with ``-1``
            input channels (presumably PyG lazy initialisation, where the size
            is inferred on the first forward pass — confirm for the operator used).
        hidden_channels: Output width of every hidden layer.
        out_channels: Output width of the final layer.
        num_layers: Number of hidden layers.
        dropout: Dropout probability applied after each hidden layer.
        skip: Whether to add the linear skip projection in each hidden layer.
    """

    def __init__(self, operator=GCNConv, input_channels=128, hidden_channels=16, out_channels=300, num_layers=1, dropout=0.5, skip=False):
        super(HomoGNN, self).__init__()
        self.dropout = dropout
        self.skip = skip
        self.convs = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()
        for _ in range(num_layers):
            # One convolution and one linear layer per hidden layer. The linear
            # layers are created regardless of `skip`, but forward() only uses
            # them when `skip` is True.
            conv = operator(-1, hidden_channels)
            lin = Linear(-1, hidden_channels)
            self.convs.append(conv)
            self.lins.append(lin)
        self.conv_out = operator(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Run the network and return log-softmax scores over dimension 1.

        Args:
            x: Node feature input passed to the first layer.
            edge_index: Connectivity passed unchanged to every ``operator`` layer.

        Returns:
            The output of the final layer after ``log_softmax(dim=1)``.
        """
        for conv, lin in zip(self.convs, self.lins):
            hidden = conv(x, edge_index).relu()
            if self.skip:
                hidden = hidden + lin(x)
            # F.dropout defaults to training=True, so without the explicit flag
            # dropout would stay active after model.eval() and make evaluation
            # outputs stochastic. Tie it to the module's train/eval mode.
            x = F.dropout(hidden, p=self.dropout, training=self.training)
        x = self.conv_out(x, edge_index)
        return F.log_softmax(x, dim=1)

In [32]:
class HomoSGNN(torch.nn.Module):
    """Single-output wrapper around one ``HomoGNN``.

    The constructor arguments are forwarded to ``HomoGNN`` (``n_layers`` is
    passed as its ``num_layers``). ``forward`` returns the network output
    wrapped in a one-element list.
    """

    def __init__(self, operator, input_channels, hidden_channels, out_channels, n_layers, dropout, skip):
        super().__init__()
        self.gnn = HomoGNN(
            operator=operator,
            input_channels=input_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=n_layers,
            dropout=dropout,
            skip=skip,
        )

    def forward(self, x, edge_index):
        """Return ``[gnn(x, edge_index)]``: a list holding the single prediction."""
        prediction = self.gnn(x, edge_index)
        return [prediction]

class HomoMGNN(torch.nn.Module):
    """Three independent ``HomoGNN`` networks, one per task.

    ``out_channels`` is indexed with the keys ``'artist'``, ``'style'`` and
    ``'genre'`` to size each network's output; all other constructor arguments
    are shared (``n_layers`` is passed as ``num_layers``). ``forward`` returns
    the three outputs as a list in the order artist, style, genre.
    """

    def __init__(self, operator, input_channels, hidden_channels, out_channels, n_layers, dropout, skip):
        super().__init__()
        # Settings common to all three task networks.
        shared = dict(
            operator=operator,
            input_channels=input_channels,
            hidden_channels=hidden_channels,
            num_layers=n_layers,
            dropout=dropout,
            skip=skip,
        )
        self.gnn_artist = HomoGNN(out_channels=out_channels['artist'], **shared)
        self.gnn_style = HomoGNN(out_channels=out_channels['style'], **shared)
        self.gnn_genre = HomoGNN(out_channels=out_channels['genre'], **shared)

    def forward(self, x, edge_index):
        """Apply each task network to the same inputs; return ``[artist, style, genre]``."""
        task_networks = (self.gnn_artist, self.gnn_style, self.gnn_genre)
        return [network(x, edge_index) for network in task_networks]
