In [1]:
# Change the working directory to the parent folder; relative paths used later
# in the notebook (e.g. "../data") are resolved from there.
%cd ..

/home/jbananafish/Desktop/Master/Thesis/code/gcnboost


In [12]:
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, Linear, GraphConv, GATConv, to_hetero
import torch_geometric.transforms as T
import torch_geometric.nn as operators

In [13]:
# Fix the random seed once, for both the default and the CUDA generators,
# so repeated runs start from the same random state.
SEED = 1
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [14]:
class HeteroGNN(torch.nn.Module):
    """Stack of graph convolutions ending in a log-softmax classifier.

    The module is written for a single node/edge type and is meant to be
    converted with ``to_hetero`` (see HeteroSGNN / HeteroMGNN).

    Args:
        operator: convolution class, instantiated as ``operator((-1, -1), channels)``
            so input sizes are inferred lazily.
        hidden_channels: output size of every hidden convolution.
        out_channels: output size of the final convolution (number of classes).
        num_layers: number of hidden convolutions before the output convolution.
        dropout: dropout probability applied after every hidden layer.
        skip: if True, add a linear projection of the layer input to the
            activated convolution output (skip connection).
    """

    def __init__(self, operator=SAGEConv, hidden_channels=16, out_channels=300, num_layers=1, dropout=0.5, skip=False):
        super(HeteroGNN, self).__init__()
        self.dropout = dropout
        self.skip = skip
        # Dropout lives in a submodule (not F.dropout) so that it follows
        # model.train() / model.eval(). A functional call with a `training`
        # keyword would be evaluated once while the model is traced by
        # to_hetero and then stay fixed.
        self.drop = torch.nn.Dropout(p=dropout)
        self.convs = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = operator((-1, -1), hidden_channels)
            # One linear projection per hidden layer; only used when skip=True.
            lin = Linear(-1, hidden_channels)
            self.convs.append(conv)
            self.lins.append(lin)
        self.conv_out = operator((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Return log-probabilities (log-softmax over dim 1) for every node in ``x``."""
        for i, conv in enumerate(self.convs):
            if self.skip:
                x = conv(x, edge_index).relu() + self.lins[i](x)
            else:
                x = conv(x, edge_index).relu()
            # Active in training mode only; identity in eval mode.
            x = self.drop(x)
        x = self.conv_out(x, edge_index)
        return F.log_softmax(x, dim=1)

In [15]:
class HeteroSGNN(torch.nn.Module):
    """Single-task model: one HeteroGNN converted to a heterogeneous GNN.

    ``forward`` wraps the prediction in a one-element list so that callers can
    treat single-task and multi-task models the same way.
    """

    def __init__(self, operator, aggr, hidden_channels, out_channels, metadata, n_layers, dropout, skip):
        super().__init__()
        homogeneous = HeteroGNN(operator, hidden_channels, out_channels, n_layers, dropout, skip)
        self.gnn = to_hetero(homogeneous, metadata, aggr=aggr)

    def forward(self, x, edge_index):
        prediction = self.gnn(x, edge_index)
        return [prediction]

class HeteroMGNN(torch.nn.Module):
    """Multi-task model with one independent heterogeneous GNN per task.

    ``out_channels`` is indexed by task name ('artist', 'style', 'genre') to
    obtain the output size of each network.
    """

    def __init__(self, operator, aggr, hidden_channels, out_channels, metadata, n_layers, dropout, skip):
        super().__init__()

        def build_head(task):
            # Same architecture for every task; only the output size differs.
            base = HeteroGNN(operator, hidden_channels, out_channels[task], n_layers, dropout, skip)
            return to_hetero(base, metadata, aggr=aggr)

        self.gnn_artist = build_head('artist')
        self.gnn_style = build_head('style')
        self.gnn_genre = build_head('genre')

    def forward(self, x, edge_index):
        """Return the per-task outputs as a list ordered artist, style, genre."""
        heads = (self.gnn_artist, self.gnn_style, self.gnn_genre)
        return [head(x, edge_index) for head in heads]

In [16]:
class ArtGraphGCNBoost:
    """Builds the ArtGraph dataset, a heterogeneous GNN and its optimizer, and
    exposes one-step training and evaluation helpers.

    Labels are stored in ``self.y`` stacked along the first dimension, one row
    per task, in the same order as the list of outputs returned by the model.
    """

    # Maps an operator name (args.operator) to the convolution class.
    operator_registry = {
        'SAGEConv': operators.SAGEConv,
        'GraphConv': operators.GraphConv,
        'GATConv': operators.GATConv,
        'GCNConv': operators.GCNConv
    }

    # Position of each task in the multi-task outputs / label stack.
    map_id2labels = {
        0: 'artist',
        1: 'style',
        2: 'genre'
    }

    map_labels2id = {
        'artist': 0,
        'style': 1,
        'genre': 2
    }

    def __init__(self, args, graph_type='hetero', training_mode='multi_task'):
        """
        Args:
            args: namespace providing operator, aggr, hidden, nlayers, dropout,
                skip, lr and (in single-task mode) label.
            graph_type: 'hetero' or 'homo'; stored on the instance, not
                otherwise used by this class.
            training_mode: 'multi_task' (artist, style and genre together) or
                'single_task' (only ``args.label``).
        """
        assert graph_type in ['hetero', 'homo'], f'unknown graph_type: {graph_type!r}'
        assert training_mode in ['multi_task', 'single_task'], f'unknown training_mode: {training_mode!r}'
        assert args.operator in self.operator_registry.keys(), f'unknown operator: {args.operator!r}'

        self.graph_type = graph_type
        self.training_mode = training_mode
        # Alias kept so code that reads the original, misspelled attribute still works.
        self.traning_mode = training_mode

        self.base_data, self.data, self.y, self.model, self.optimizer = self._bootstrap(args)
        # Reuse the graph loaded in _bootstrap instead of indexing the dataset again.
        self.artworks = self.data['artwork']
        self.train_mask = self.artworks.train_mask
        self.val_mask = self.artworks.val_mask
        self.test_mask = self.artworks.test_mask

    def _bootstrap(self, args):
        """Load the dataset and create the model, label stack and optimizer.

        Returns:
            (base_data, data, y, model, optimizer)
        """
        # NOTE(review): ArtGraph is not imported in the visible part of the
        # notebook; it must be in scope before this method runs.
        base_data = ArtGraph("../data", preprocess='node2vec', transform=T.ToUndirected(), features=True, type='ekg')
        data = base_data[0]
        artworks = data['artwork']

        # Hyper-parameters shared by both model variants.
        model_kwargs = dict(operator=self.operator_registry[args.operator],
                            aggr=args.aggr,
                            hidden_channels=args.hidden,
                            metadata=data.metadata(),
                            n_layers=args.nlayers,
                            dropout=args.dropout,
                            skip=args.skip)

        if self.training_mode == 'multi_task':
            model = HeteroMGNN(out_channels=base_data.num_classes, **model_kwargs)
            y = torch.stack([artworks.y_artist, artworks.y_style, artworks.y_genre])
        elif self.training_mode == 'single_task':
            model = HeteroSGNN(out_channels=base_data.num_classes[args.label], **model_kwargs)
            y = torch.stack([artworks[f'y_{args.label}']])
        else:
            # Only reachable when the asserts in __init__ are disabled (python -O).
            raise ValueError(f'unknown training_mode: {self.training_mode!r}')

        # NOTE(review): the optimizer is created before any forward pass, i.e.
        # while the lazily sized (-1) layers are still uninitialized -- confirm
        # this is intended.
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=3e-4)

        return base_data, data, y, model, optimizer

    def get_accuracy(self, predicted, labels):
        """Fraction of rows of ``predicted`` whose argmax over dim 1 equals ``labels``."""
        return predicted.argmax(dim=1).eq(labels).sum()/predicted.shape[0]

    def get_accuracies(self, predicted, labels, mask):
        """Per-task accuracy on the 'artwork' nodes selected by ``mask``."""
        return [self.get_accuracy(predicted[i]['artwork'][mask], labels[i][mask])
                for i, _ in enumerate(labels)]

    def get_loss(self, predicted, labels):
        """Negative log-likelihood loss of log-probabilities ``predicted``."""
        # .long() converts to int64 on the tensor's current device, whereas
        # .type(torch.LongTensor) always yields a CPU tensor.
        return F.nll_loss(predicted, labels.long())

    def get_losses(self, predicted, labels, mask):
        """Per-task loss on the 'artwork' nodes selected by ``mask``."""
        return [self.get_loss(predicted[i]['artwork'][mask], labels[i][mask])
                for i, _ in enumerate(labels)]

    def hetero_training(self):
        """Run one full-graph optimisation step.

        Returns:
            (out, train_losses, train_accuracies), where ``out`` is the model
            output computed before the parameter update.
        """
        self.model.train()
        self.optimizer.zero_grad()
        out = self.model(self.data.x_dict, self.data.edge_index_dict)

        train_losses = self.get_losses(out, self.y, self.train_mask)
        # The tasks are optimised jointly through the sum of their losses.
        train_total_loss = sum(train_losses)

        train_total_loss.backward()
        self.optimizer.step()

        train_accuracies = self.get_accuracies(out, self.y, self.train_mask)

        return out, train_losses, train_accuracies

    def hetero_test(self, out=None):
        """Compute validation and test losses/accuracies.

        Args:
            out: model outputs to score. When omitted, a fresh forward pass is
                run in eval mode without gradient tracking, and the model's
                previous train/eval mode is restored afterwards.

        Returns:
            (val_losses, val_accuracies, test_losses, test_accuracies)
        """
        if out is None:
            was_training = self.model.training
            self.model.eval()
            with torch.no_grad():
                out = self.model(self.data.x_dict, self.data.edge_index_dict)
            self.model.train(was_training)

        val_losses = self.get_losses(out, self.y, self.val_mask)
        test_losses = self.get_losses(out, self.y, self.test_mask)

        val_accuracies = self.get_accuracies(out, self.y, self.val_mask)
        test_accuracies = self.get_accuracies(out, self.y, self.test_mask)

        return val_losses, val_accuracies, test_losses, test_accuracies

In [20]:
import argparse

# Experiment configuration, exposed as command-line style options.
parser = argparse.ArgumentParser()
parser.add_argument('--exp', type=str, default='default', help='Experiment name.')
parser.add_argument('--type', type=str, default='hetero', help='Graph type (hetero|homo).')
parser.add_argument('--mode', type=str, default='multi_task', help='Training mode (multi_task|single_task).')
# NOTE(review): the default 'all' is not one of artist|style|genre; single_task
# mode indexes the labels by this value, so it must be set explicitly there.
parser.add_argument('--label', type=str, default='all', help='Label to predict (artist|style|genre).')
parser.add_argument('--epochs', type=int, default=1, help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.')
parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
parser.add_argument('--nlayers', type=int, default=1, help='Number of layers.')
parser.add_argument('--dropout', type=float, default=0, help='Dropout rate (1 - keep probability).')
# NOTE(review): HeteroGNN builds every operator with a (-1, -1) input size;
# confirm that GCNConv accepts it, otherwise pick another default here.
parser.add_argument('--operator', type=str, default='GCNConv', help='The graph convolutional operator.')
parser.add_argument('--aggr', type=str, default='sum', help='Aggregation function.')
# The default must be the boolean False: the string 'False' is truthy and
# would enable skip connections even when the flag is not given.
parser.add_argument('--skip', action='store_true', default=False, help='Add skip connection.')
# parse_known_args() returns unrecognized arguments in `unknown` instead of
# exiting, so extra arguments present in sys.argv do not abort the cell.
args, unknown = parser.parse_known_args()

In [None]:
# Train for args.epochs epochs and print per-split loss/accuracy after each one.
# NOTE(review): `gcn` is not created in the visible cells; it is expected to be
# an ArtGraphGCNBoost instance built from `args` before this cell runs.
for epoch in tqdm(range(0, args.epochs)):
    out, train_losses, train_accuracies = gcn.hetero_training()
    val_losses, val_accuracies, test_losses, test_accuracies = gcn.hetero_test(out)

    # (split name, per-task losses, per-task accuracies), in reporting order.
    splits = (
        ('train', train_losses, train_accuracies),
        ('val', val_losses, val_accuracies),
        ('test', test_losses, test_accuracies),
    )

    if args.mode == 'multi_task':
        for split, losses, accuracies in splits:
            for i, (loss, accuracy) in enumerate(zip(losses, accuracies)):
                label = gcn.map_id2labels[i]
                print(f'{label}_{split}_loss', round(loss.item(), 4))
                # Scale to percent before rounding: round(x, 2) * 100 can print
                # values such as 28.999999999999996.
                print(f'{label}_{split}_accuracy', round(accuracy.item() * 100, 2))
    if args.mode == 'single_task':
        for split, losses, accuracies in splits:
            print(f'{args.label}_{split}_loss', round(losses[0].item(), 4))
            print(f'{args.label}_{split}_accuracy', round(accuracies[0].item() * 100, 2))