In [35]:
import os
from pathlib import Path
import math
import json
import numpy as np
import torch
import time
import matplotlib.pyplot as plt
import importlib
from sklearn.linear_model import SGDClassifier
from toolbox import losses
from toolbox.losses import coloring_loss, triplet_loss
from toolbox import metrics
from loaders.loaders import siamese_loader
from toolbox.metrics import all_losses_acc, accuracy_linear_assignment
from toolbox.utils import check_dir
from models import coloring_model, get_siamese_model_test
from models import utils
from loaders import data_generator
importlib.reload(data_generator)
from loaders.data_generator import KCOL_Generator, QAP_Generator

In [36]:
def get_device_config(model_path):
    """ Get the same device as used for training """
    config_file = os.path.join(model_path,'config.json')
    with open(config_file) as json_file:
        config_model = json.load(json_file)
    use_cuda = not config_model['cpu'] and torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    return config_model, device

def acc_2_error(mean_acc, q_acc):
    error = q_acc-mean_acc[:,np.newaxis]
    error[:,0] = -error[:,0]
    return error

def compute_dataset(args,path_dataset,train=True,bs=10):
    num_batches = math.ceil(args['num_examples_val']/bs)
    if train:
        gene = KCOL_Generator('train', args, path_dataset)
    else:
        gene = KCOL_Generator('test', args, path_dataset)
    gene.load_dataset()
    loader = siamese_loader(gene, bs, gene.constant_n_vertices)
    return loader

def compute_quant(all_acc,quant_low=0.1,quant_up=0.9):
    mean_acc = np.mean(all_acc,1)
    num = len(mean_acc)
    q_acc = np.zeros((num,2))
    for i in range(num):
        q_acc[i,:] = np.quantile(all_acc[i,:],[quant_up, quant_low])
    return mean_acc, q_acc

def train_epoch(model, embed_model, train_loader):
    loss_fn = coloring_loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    model.train()
    cum_loss = 0
    for idx, (graph,tgt) in enumerate(train_loader):
        graph['input'] = graph['input'].to(device)
        embed = embed_model.node_embedder(graph)['ne/suffix']
        embed = torch.permute(embed,(0,2,1))
        
        tgt = tgt['input'].to(device)
        
        out = model(embed)
        
        optimizer.zero_grad()

        loss = loss_fn(graph['input'], out, tgt)
        loss.backward()

        optimizer.step()
        cum_loss += loss.item()
    return cum_loss / len(train_loader)


def evaluate(model, val_loader):
    model.eval()
    cum_loss = 0
    for idx, (graph, tgt) in (enumerate(valid_loader)):
        graph['input'] = graph['input'].to(device)
        embed = embed_model.node_embedder(graph)['ne/suffix']
        embed = torch.permute(embed,(0,2,1))
        
        tgt = tgt.to(device)
        
        out = model(embed)
        
        loss = loss_fn(out, tgt)
        cum_loss += loss.item()
    return cum_loss / len(val_loader) 
    

## Loading the pretrained model

In [37]:
cwd = os.getcwd()
cwd = "/".join(cwd.split("/")[:-1])

In [38]:
model_path = cwd+'/experiments-gnn/qap/expe_new/node_embedding_rec_Regular_150_0.05/05-15-23-16-14'
config_model, device = get_device_config(model_path)
embed_model = get_siamese_model_test(config_model["data"]["test"]["path_model"])
embed_model.to(device)

Siamese_Node_Exp(
  (node_embedder): Network(
    (ne_bm_in): GraphNorm()
    (ne_bm_block1_mlp3): MlpBlock_Real(
      (convs): ModuleList(
        (0): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
        (1-2): 2 x Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (gn): GraphNorm()
    )
    (ne_bm_cat1): Concat()
    (ne_bm_block2_mlp1): MlpBlock_Real(
      (convs): ModuleList(
        (0): Conv2d(66, 64, kernel_size=(1, 1), stride=(1, 1))
        (1-2): 2 x Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (gn): GraphNorm()
    )
    (ne_bm_block2_mlp2): MlpBlock_Real(
      (convs): ModuleList(
        (0): Conv2d(66, 64, kernel_size=(1, 1), stride=(1, 1))
        (1-2): 2 x Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (gn): GraphNorm()
    )
    (ne_bm_block2_mult): Matmul()
    (ne_bm_block2_cat): Concat()
    (ne_bm_block2_mlp3): MlpBlock_Real(
      (convs): ModuleList(
        (0): Conv2d(130, 64, kernel_size=(1, 1), stride

## Generating data and loaders

In [39]:
config_model["data"]["test"]["k"] = 5
config_model["data"]["test"]["num_examples_train"] = 2000
config_model["data"]["test"]["num_examples_val"] = 100

In [None]:
args = config_model["data"]["test"]
train_loader = compute_dataset(args, cwd+'/experiments-gnn/kcol/data')
valid_loader = compute_dataset(args, cwd+'/experiments-gnn/kcol/data', train=False)

Creating dataset at /home/mdepres/experiments-gnn/kcol/data/Color_ErdosRenyi_2000_150_1.0_0.05/train.pkl


 69%|███████████████████████████████████████████████████▍                       | 1372/2000 [04:14<01:55,  5.42it/s]

## Training a logistic regression

In [34]:
clf = SGDClassifier(loss='log_loss', warm_start=True)

for idx, (graph,tgt) in enumerate(train_loader):
    graph['input'] = graph['input'].to(device)
    embed = embed_model.node_embedder(graph)['ne/suffix']
    
    embed = embed.cpu().detach().numpy()
    embed = np.swapaxes(embed,1,2)
    embed = np.resize(embed, (embed.shape[0]*embed.shape[1],embed.shape[-1]))
    tgt = tgt['input']
    tgt = np.resize(tgt, (tgt.shape[0]*tgt.shape[1],tgt.shape[-1]))
    
    print(tgt.shape, embed.shape)
    clf.partial_fit(embed, tgt)

(1500, 5) (1500, 64)


ValueError: y should be a 1d array, got an array of shape (1500, 5) instead.

## Training the coloring model

In [None]:
model = coloring_model.ColoringModel(args['n_vertices'],embed_dim=64, k=args["k"])
model.to(device)

In [None]:
num_epochs = 30
for epoch in range(1, num_epochs+1):
    start_time = time.time()
    train_loss = train_epoch(model, embed_model, train_loader)
    end_time = time.time()
    val_loss = evaluate(model, valid_loader)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
          f"Epoch time = {(end_time - start_time):.3f}s"))