In [2]:
%cd ..

/home/jbananafish/Desktop/Master/Thesis/code/gcnboost


In [3]:
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, Linear, GraphConv, GATConv, to_hetero
import torch_geometric.transforms as T
import torch_geometric.nn as operators

from src.data.artgraph import ArtGraph

In [4]:
torch.manual_seed(1)
torch.cuda.manual_seed(1)

In [5]:
#base_data = ArtGraph("./ekg", preprocess='node2vec', features=True, type='ekg')
base_data = ArtGraph("data", preprocess='node2vec', transform=T.ToUndirected(), features=True, type='ekg')

In [6]:
data = base_data[0]

## Some graph-level information

In [7]:
print(f"Number of artist classes: {base_data.num_classes['artist']}")
print(f"Number of style classes: {base_data.num_classes['style']}")
print(f"Number of genre classes: {base_data.num_classes['genre']}")
print(f"Number of input features: {base_data.num_features}")

Number of artist classes: 300
Number of style classes: 83
Number of genre classes: 50
Number of input features: 128


## Some node-level information

In [8]:
data = base_data[0]

In [9]:
print(data)

HeteroData(
  [1martwork[0m={
    x=[61477, 128],
    y_artist=[61477],
    y_style=[61477],
    y_genre=[61477],
    train_mask=[61477],
    val_mask=[61477],
    test_mask=[61477]
  },
  [1martist[0m={ x=[300, 128] },
  [1mgallery[0m={ x=[1090, 128] },
  [1mcity[0m={ x=[665, 128] },
  [1mcountry[0m={ x=[64, 128] },
  [1mstyle[0m={ x=[83, 128] },
  [1mperiod[0m={ x=[53, 128] },
  [1mgenre[0m={ x=[50, 128] },
  [1mserie[0m={ x=[610, 128] },
  [1mauction[0m={ x=[5, 128] },
  [1mtag[0m={ x=[5146, 128] },
  [1mmedia[0m={ x=[160, 128] },
  [1msubject[0m={ x=[2161, 128] },
  [1mtraining_node[0m={ x=[108, 128] },
  [1mfield[0m={ x=[65, 128] },
  [1mmovement[0m={ x=[121, 128] },
  [1mpeople[0m={ x=[48, 128] },
  [1m(artist, influenced_rel, artist)[0m={ edge_index=[2, 62] },
  [1m(artist, subject_rel, subject)[0m={ edge_index=[2, 3648] },
  [1m(artist, training_rel, training_node)[0m={ edge_index=[2, 130] },
  [1m(artist, field_rel, field)[0m={ edge_in

In [10]:
class HomoGNN(torch.nn.Module):
    def __init__(self, operator=GCNConv, input_channels=128, hidden_channels=16, out_channels=300, num_layers=1, dropout=0.5, skip=False):
        super(HomoGNN, self).__init__()
        self.dropout = dropout
        self.skip = skip
        self.convs = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()
        #self.convs.append(operator(input_channels, hidden_channels))
        for _ in range(num_layers):
            conv = operator(-1, hidden_channels)
            lin = Linear(-1, hidden_channels)
            self.convs.append(conv)
            self.lins.append(lin)
        self.conv_out = operator(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            if self.skip:
                x = conv(x, edge_index).relu() + self.lins[i](x)
            else:
                x = conv(x, edge_index).relu()
            x = F.dropout(x, self.dropout)
        x = self.conv_out(x, edge_index)
        return F.log_softmax(x, dim=1)

In [11]:
class HomoSGNN(torch.nn.Module):
    def __init__(self, operator, input_channels, hidden_channels, out_channels, n_layers, dropout, skip):
        super(HomoSGNN, self).__init__()
        self.gnn = HomoGNN(operator, input_channels, hidden_channels, out_channels, n_layers, dropout, skip)

    def forward(self, x, edge_index):
        return [self.gnn(x, edge_index)]

class HomoMGNN(torch.nn.Module):
    def __init__(self, operator, input_channels, hidden_channels, out_channels, n_layers, dropout, skip):
        super(HomoMGNN, self).__init__()
        self.gnn_artist = HomoGNN(operator, input_channels, hidden_channels, out_channels['artist'], n_layers, dropout, skip)
        self.gnn_style = HomoGNN(operator, input_channels, hidden_channels, out_channels['style'], n_layers, dropout, skip)
        self.gnn_genre = HomoGNN(operator, input_channels, hidden_channels, out_channels['genre'], n_layers, dropout, skip)

    def forward(self, x, edge_index):
        return [self.gnn_artist(x, edge_index), self.gnn_style(x, edge_index), self.gnn_genre(x, edge_index)]

In [24]:
class ArtGraphGCNBoost:
    operator_registry = {
        'SAGEConv': operators.SAGEConv,
        'GraphConv': operators.GraphConv,
        'GATConv': operators.GATConv,
        'GCNConv': operators.GCNConv
    }

    map_id2labels = {
        0: 'artist',
        1: 'style',
        2: 'genre'
    }

    map_labels2id = {
        'artist': 0,
        'style': 1,
        'genre': 2
    }

    def __init__(self, args, training_mode='multi_task'):
        
        self.traning_mode = training_mode
        assert training_mode in ['multi_task', 'single_task']
        assert args.operator in self.operator_registry.keys()

        self.base_data, self.data, self.y, self.model, self.optimizer = self._bootstrap(args)
        self.artworks = self.base_data[0]['artwork']
        self.train_mask = self.artworks.train_mask
        self.val_mask = self.artworks.val_mask
        self.test_mask = self.artworks.test_mask

    def _bootstrap(self, args):
        base_data = ArtGraph("data", preprocess='node2vec', transform=T.ToUndirected(), features=True, type='ekg')
        data = base_data[0]
        data = data.to_homogeneous()
        if self.traning_mode == 'multi_task':
            model = HomoMGNN(operator=self.operator_registry[args.operator],
                                input_channels=base_data.num_features,
                                hidden_channels=args.hidden,
                                out_channels=base_data.num_classes,
                                n_layers=args.nlayers,
                                dropout=args.dropout,
                                skip=args.skip)
            y = torch.stack([base_data[0]['artwork'].y_artist, base_data[0]['artwork'].y_style, base_data[0]['artwork'].y_genre])
        if self.traning_mode == 'single_task':
            model = HomoSGNN(operator=self.operator_registry[args.operator],
                                input_channels=base_data.num_features,
                                hidden_channels=args.hidden,
                                out_channels=base_data.num_classes[args.label],
                                n_layers=args.nlayers,
                                dropout=args.dropout,
                                skip=args.skip)
            y = torch.stack([base_data[0]['artwork'][f'y_{args.label}']])
        
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=3e-4)
        
        return base_data, data, y, model, optimizer

    def get_accuracy(self, predicted, labels):
        return predicted.argmax(dim=1).eq(labels).sum()/predicted.shape[0]

    def get_accuracies_homo(self, predicted, labels, mask):
        size = self.train_mask.shape[0]
        accuracies = [] 
        for i, _ in enumerate(labels):
            accuracies.append(self.get_accuracy(predicted[i][:size][mask], labels[i][mask]))
        return accuracies

    def get_loss(self, predicted, labels):
        return F.nll_loss(predicted, labels.type(torch.LongTensor))
    
    def get_losses_homo(self, predicted, labels, mask):
        size = self.train_mask.shape[0]
        losses = []
        for i, _ in enumerate(labels):
            losses.append(self.get_loss(predicted[i][:size][mask], labels[i][mask]))
        return losses

    def homo_training(self):
        self.model.train()
        self.optimizer.zero_grad()
        out = self.model(self.data.x, self.data.edge_index)

        train_losses = self.get_losses_homo(out, self.y, self.train_mask)
        train_total_loss = sum(train_losses)

        train_total_loss.backward()
        self.optimizer.step()

        train_accuracies = self.get_accuracies_homo(out, self.y, self.train_mask)

        return out, train_losses, train_accuracies

    def homo_test(self, out):
        val_losses = self.get_losses_homo(out, self.y, self.val_mask)
        test_losses = self.get_losses_homo(out, self.y, self.test_mask)

        val_accuracies = self.get_accuracies_homo(out, self.y, self.val_mask)
        test_accuracies = self.get_accuracies_homo(out, self.y, self.test_mask)

        return val_losses, val_accuracies, test_losses, test_accuracies

In [33]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--exp', type=str, default='default', help='Experiment name.')
parser.add_argument('--type', type=str, default='homo', help='Graph type (hetero|homo).')
parser.add_argument('--mode', type=str, default='single_task', help='Training mode (multi_task|single_task).')
parser.add_argument('--label', type=str, default='artist', help='Label to predict (artist|style|genre).')
parser.add_argument('--epochs', type=int, default=1, help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.')
parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
parser.add_argument('--nlayers', type=int, default=1, help='Number of layers.')
parser.add_argument('--dropout', type=float, default=0, help='Dropout rate (1 - keep probability).')
parser.add_argument('--operator', type=str, default='GCNConv', help='The graph convolutional operator.')
parser.add_argument('--aggr', type=str, default='sum', help='Aggregation function.')
parser.add_argument('--skip', action='store_true', default='False', help='Add skip connection.')
args, unknown = parser.parse_known_args()

In [34]:
gcn = ArtGraphGCNBoost(args, training_mode=args.mode)

In [35]:
for epoch in tqdm(range(0, args.epochs)):
    out, train_losses, train_accuracies = gcn.homo_training()
    val_losses, val_accuracies, test_losses, test_accuracies = gcn.homo_test(out)
    if args.mode == 'multi_task':
        for i, train_loss_acc in enumerate(zip(train_losses, train_accuracies)):
            print(f'{gcn.map_id2labels[i]}_train_loss', round(train_loss_acc[0].detach().item(), 4))
            print(f'{gcn.map_id2labels[i]}_train_accuracy', round(train_loss_acc[1].item(), 2) * 100)
        for i, val_loss_acc in enumerate(zip(val_losses, val_accuracies)):
            print(f'{gcn.map_id2labels[i]}_val_loss', round(val_loss_acc[0].detach().item(), 4))
            print(f'{gcn.map_id2labels[i]}_val_accuracy', round(val_loss_acc[1].item(), 2) * 100)
        for i, test_loss_acc in enumerate(zip(test_losses, test_accuracies)):
            print(f'{gcn.map_id2labels[i]}_test_loss', round(test_loss_acc[0].detach().item(), 4))
            print(f'{gcn.map_id2labels[i]}_test_accuracy', round(test_loss_acc[1].item(), 2) * 100)
    if args.mode == 'single_task':
        print(f'{args.label}_train_loss', round(train_losses[0].detach().item(), 4))
        print(f'{args.label}_train_accuracy', round(train_accuracies[0].item(), 2) * 100)
        print(f'{args.label}_val_loss', round(val_losses[0].detach().item(), 4))
        print(f'{args.label}_val_accuracy', round(val_accuracies[0].item(), 2) * 100)
        print(f'{args.label}_test_loss', round(test_losses[0].detach().item(), 4))
        print(f'{args.label}_test_accuracy', round(test_accuracies[0].item(), 2) * 100)

100%|██████████| 1/1 [00:21<00:00, 21.52s/it]

artist_train_loss 5.7038
artist_train_accuracy 0.0
artist_val_loss 5.7025
artist_val_accuracy 1.0
artist_test_loss 5.7031
artist_test_accuracy 1.0



