In [1]:
# %%
import os

# Both variables are set before transformers and wandb are imported further down.
# NOTE(review): presumably TOKENIZERS_PARALLELISM='false' silences the tokenizers
# fork/parallelism warning and WANDB_DIR redirects wandb's local run files —
# confirm against the respective library documentation.
os.environ["TOKENIZERS_PARALLELISM"] = 'false'
os.environ["WANDB_DIR"] = '/projets/melodi/gsantoss/wandbt'

from cmatcher.owl_utils import *
from cmatcher.cqa_search import *
from cmatcher.eval_utils import *
from transformers import AutoTokenizer
import dill
import torch
import torch.optim as optim
import copy
import tqdm
from cmatcher.model import *
from tqdm.auto import tqdm
import random
import wandb

import argparse

# Seed torch's and the stdlib's random number generators with a fixed value.
torch.manual_seed(0)
random.seed(0)

In [5]:





def parse_arguments():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: carries a single attribute ``sweep``, an optional
        integer. It is ``None`` when ``--sweep`` is absent or is given without
        a value.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--sweep', dest='sweep', nargs='?', type=int)
    namespace = parser.parse_args()
    return namespace


# %%


# args = parse_arguments()

# Value lists that all_combinations() crosses to enumerate sweep configurations.
test_onts = [
    'cmt',
    'conference',
    'confOf',
    'edas',
    'ekaw',
]
language_models = [
    'BAAI/bge-base-en',
    'infgrad/stella-base-en-v2',
    'BAAI/bge-large-en-v1.5',
    'llmrails/ember-v1',
    'thenlper/gte-large',
]
architectures = ['lm', 'gnn', 'sgnn']
lm_grad = ['none', 'grad']
pred = ['none', 'pred']
# Name kept as-is (sic) because all_combinations() reads it under this spelling.
dephs = list(range(1, 5))


def all_combinations():
    """Enumerate the sweep grid.

    Returns:
        list[tuple]: ``(test_ont, language_model, architecture, grad, pred, depth)``
        tuples, ordered by ontology, then language model, then architecture.
        The ``'lm'`` architecture contributes exactly one tuple per
        (ontology, language model) pair, fixed to ``('grad', 'none', 0)``;
        every other architecture is crossed with ``lm_grad``, ``pred`` and
        ``dephs``.
    """
    grid = []
    for ont in test_onts:
        for model_name in language_models:
            for arch in architectures:
                if arch != 'lm':
                    # Full cross product for the graph-based architectures.
                    grid.extend(
                        (ont, model_name, arch, grad_mode, pred_mode, depth)
                        for grad_mode in lm_grad
                        for pred_mode in pred
                        for depth in dephs
                    )
                else:
                    grid.append((ont, model_name, arch, 'grad', 'none', 0))

    return grid


# test_ont, language_model, architecture, grad, cpred, depth = all_combinations()[args.sweep]
# test_ont, language_model, architecture, grad, cpred, depth = all_combinations()[0]

# config = {
#     'test_ont': test_ont,
#     'learning_rate': 0.00001,
#     'language_model': language_model,
#     'architecture': architecture,
#     'pred': cpred,
#     'epochs': 5,
#     'batch_size': 2,
#     'evm_th': 0.9,
#     'ev_sim_threshold': 0.8,
#     'sim_margin': 0.8,
#     'depth': depth,
#     'grad': grad
# }



# Settings of this run; the same dict is handed to wandb.init and printed.
config = dict(
    test_ont='edas',
    learning_rate=0.00001,  # lr of the Adam optimizer
    language_model='google-bert/bert-base-uncased',  # tokenizer / Model checkpoint name
    architecture='gnn',
    pred='none',
    epochs=5,
    batch_size=2,
    evm_th=0.9,
    ev_sim_threshold=0.8,  # passed to evm(...) as `th`
    sim_margin=0.8,  # margin of the triplet loss
    depth=2,  # passed to Model as `d`
    grad='grad',  # Model gets lm_grad=True only when this equals 'grad'
)

# Input locations and the scratch directory used for pickle caches and the saved model.
onts_path = '/projets/melodi/gsantoss/data/oaei/tracks/conference/onts'
temp_path = '/projets/melodi/gsantoss/tmp'
entities_path = '/projets/melodi/gsantoss/data/complex/entities-cqas'
cqa_path = '/projets/melodi/gsantoss/data/complex/CQAs'


# Alternative set of locations, kept for reference:
# cqa_path = '/home/guilherme/Documents/complex/CQAs'
# entities_path = '/home/guilherme/Documents/complex/entities-cqas'
# temp_path = '/tmp'
# onts_path = '/home/guilherme/Documents/kg/conference'


# Maps '<name>.owl' to the full path of that file under onts_path.
ontology_paths = {
    f'{stem}.owl': f'{onts_path}/{stem}.owl'
    for stem in ('edas', 'ekaw', 'confOf', 'conference', 'cmt')
}




# Name of the ontology selected by the run configuration.
test_ont = config['test_ont']


cqas = load_cqas(cqa_path)


# Per-ontology entity data, cached as a dill pickle in the scratch directory.
idata_cache = f'{temp_path}/idata.pkl'

if os.path.exists(idata_cache):
    # NOTE(review): dill/pickle can execute arbitrary code while loading;
    # only read cache files that this script wrote itself.
    with open(idata_cache, 'rb') as f:
        train_ont_cqa_subg = dill.load(f)
        print('loaded from cache.')
else:
    # Compute first, and only once: opening the cache with 'wb' before the data
    # exists would leave an empty file behind if load_entities raised, and every
    # later run would then fail while unpickling it.
    train_ont_cqa_subg = load_entities(entities_path, ontology_paths)
    with open(idata_cache, 'wb') as f:
        # Persist the very object that is kept in memory, not a second computation.
        dill.dump(train_ont_cqa_subg, f)
        


class RawDataset:
    """Container for the four values (``ifd``, ``mc``, ``mp``, ``fres``) of one ontology.

    The values are unpickled from ``{temp_path}/{test_ont}.pkl`` when that file
    exists and ``use_cache`` is true. Otherwise they are rebuilt with
    ``load_sg`` and ``build_raw_ts`` and written to that same file.

    Args:
        temp_path: directory holding the pickle cache.
        test_ont: ontology name; selects the cache file, the ``.owl`` file and
            the entry taken from the ``load_sg`` result.
        entities_path: forwarded to ``load_sg``.
        ontology_paths: forwarded to ``load_sg``.
        onts_path: directory containing ``{test_ont}.owl``.
        use_cache: when false, an existing cache file is ignored for reading
            and is overwritten.
    """

    def __init__(self, temp_path, test_ont, entities_path, ontology_paths, onts_path, use_cache=True):
        cache_file = f'{temp_path}/{test_ont}.pkl'
        read_cache = os.path.exists(cache_file) and use_cache

        if read_cache:
            with open(cache_file, 'rb') as handle:
                self.ifd, self.mc, self.mp, self.fres = dill.load(handle)
                print('loaded from cache.')
            return

        # Cache miss (or cache disabled): rebuild, then persist for the next run.
        subgraphs = load_sg(entities_path, ontology_paths)
        self.ifd, self.mc, self.mp, self.fres = build_raw_ts(
            f'{onts_path}/{test_ont}.owl', subgraphs[test_ont], workers=4)
        with open(cache_file, 'wb') as handle:
            dill.dump((self.ifd, self.mc, self.mp, self.fres), handle)
        
        
# Load (or build and cache) the raw data of the configured test ontology and
# expose its four fields as module-level names.
raw_dataset = RawDataset(temp_path, test_ont, entities_path, ontology_paths, onts_path)
ifd, mc, mp, fres = raw_dataset.ifd, raw_dataset.mc, raw_dataset.mp, raw_dataset.fres

# Independent copy of the loaded entity data with the test ontology's entry removed.
conts_cqa_subg = copy.deepcopy(train_ont_cqa_subg)
del conts_cqa_subg[test_ont]

tokenizer = AutoTokenizer.from_pretrained(config['language_model'])

# Everything the per-epoch eval_test call consumes later on.
root_entities, graph_data, cq, cqid, caq, cqmask, tor = prepare_eval_dataset(
    test_ont,
    cqas,
    ifd,
    tokenizer,
    mc,
    mp,
    fres,
)

# Start the wandb run: the whole config dict is recorded, and runs are grouped by
# the language model, architecture, pred and grad settings.
wandb.init(
    project='cmatcher',
    config=config,
    group=f'{config["language_model"]}-{config["architecture"]}-{config["pred"]}-{config["grad"]}',
    # NOTE(review): the underscore-prefixed settings look like private wandb
    # options — confirm they are accepted by the installed wandb version.
    settings=wandb.Settings(_disable_stats=True, _disable_meta=True)
)

print(config)

# %%
print('start training')

# %%
# Build the model; lm_grad is True only when config['grad'] equals 'grad'.
model = Model(config['language_model'], d=config['depth'], lm_grad=config['grad'] == 'grad')
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

lh = []  # one mean training loss per epoch
evh = []  # one evm(...) result per epoch
epochs = config['epochs']
batch_size = config['batch_size']
progress = None  # the tqdm bar is created by the training cell

# Triplet loss using cosine distance (1 - cosine similarity along dim 1).
# NOTE(review): `nn` is not among the explicit imports visible in this file;
# presumably it arrives through one of the star imports — confirm.
triplet_loss = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1.0 - torch.cosine_similarity(x, y, dim=1), margin=config['sim_margin'])

print('build datasets')

# Remove two ontologies from the loaded entity data before building the raw data.
# conts_cqa_subg was deep-copied before this point, so it still holds these entries.
# NOTE(review): assuming a dict, these in-place deletions raise KeyError if they
# are executed again without reloading train_ont_cqa_subg.
del train_ont_cqa_subg['confOf']
# del train_ont_cqa_subg['conference']
# del train_ont_cqa_subg['edas']
del train_ont_cqa_subg['ekaw']




raw_data = build_raw_data(train_ont_cqa_subg, cqas)


# Training dataset: built from the entry of raw_data keyed by the test ontology
# together with the copy of the entity data that excludes that ontology.
dataset = CQADataset(tokenizer, conts_cqa_subg, raw_data[test_ont], filter_bn=False)
loader = DataLoader(dataset, batch_size=batch_size)

# Unshuffled loaders over the pieces returned by prepare_eval_dataset; they are
# passed to eval_test after every epoch.
cqloader = DataLoader(cqid, batch_size=batch_size, shuffle=False)
acqloader = [DataLoader(a, batch_size=batch_size, shuffle=False) for a in caq]
graph_loader = DataLoader(graph_data, batch_size=batch_size, shuffle=False)


ERROR! Session/line number was not unique in database. History logging moved to new session 1131


  0%|          | 0/101 [00:00<?, ?it/s]

loaded from cache.
loaded from cache.


{'test_ont': 'edas', 'learning_rate': 1e-05, 'language_model': 'google-bert/bert-base-uncased', 'architecture': 'gnn', 'pred': 'none', 'epochs': 5, 'batch_size': 2, 'evm_th': 0.9, 'ev_sim_threshold': 0.8, 'sim_margin': 0.8, 'depth': 2, 'grad': 'grad'}
start training
build datasets


In [6]:
# Size of the dataset that the training loader iterates over.
print(len(dataset))

328


In [7]:

# Training cell: runs `epochs` passes over `loader`, minimising the triplet loss.
# After each epoch it records the mean loss, runs evm and eval_test, and logs to
# wandb. At the end it closes the progress bar, finishes the wandb run and saves
# the model weights.
print('data prepared')
# NOTE(review): this only sets a plain attribute on the model object; nothing in
# the visible code reads it — confirm it has an effect.
model.find_unused_parameters = False
# NOTE(review): `progress` starts as None; `not progress` also evaluates the
# truthiness of an existing tqdm bar, so `progress is None` may be the intended test.
if not progress:
    progress = tqdm(total=epochs * len(loader))

print('start training')
# evh.append(evm(model, dataset, device, th=config["ev_sim_threshold"]))
# eval_test(model, device, cqloader, graph_loader, cq, root_entities, fres, acqloader, cqmask, tor)
# wandb.log({'global/acc': evh[-1]})

for e in range(epochs):

    # Switch back to training mode at the start of every epoch.
    model.train()

    el = []  # detached loss of every batch in this epoch
    for batch in loader:

        optimizer.zero_grad()

        # One forward pass encodes the CQA batch plus the positive and the
        # negative subgraphs; every input tensor is moved to `device` first.
        cqs, sbgs, nsbg = model(cqa=batch.cqs.to(device), positive_sbg=(batch.x_sf.to(device), batch.x_s.to(device),
                                                                     batch.edge_index_s.to(device),
                                                                     batch.edge_feat_sf.to(device),
                                                                     batch.edge_feat_s.to(device)),
                                negative_sbg=(batch.x_nf.to(device), batch.x_n.to(device),
                                              batch.edge_index_n.to(device), batch.edge_feat_nf.to(device),
                                              batch.edge_feat_n.to(device)))


        # Select rows of the subgraph outputs with the index fields of the batch.
        # NOTE(review): presumably rsi/rni pick one row per CQA so that the three
        # loss inputs line up — confirm against CQADataset.
        isbgs = sbgs[batch.rsi]
        isbgn = nsbg[batch.rni]


        # Anchor = CQA embedding, positive/negative = selected subgraph rows.
        loss = triplet_loss(cqs, isbgs, isbgn)
        # Keep only the value; detach() drops the autograd graph.
        el.append(loss.detach())
        loss.backward()

        optimizer.step()
        progress.update(1)


    # Mean training loss of the epoch as a Python float.
    lh.append(torch.stack(el).mean().item())

    # Per-epoch evaluation: evm on the training dataset, then eval_test on the
    # evaluation loaders; the evm result and the mean loss are logged to wandb.
    evh.append(evm(model, dataset, device, th=config["ev_sim_threshold"]))
    eval_test(model, device, cqloader, graph_loader, cq, root_entities, fres, acqloader, cqmask, tor)
    wandb.log({'global/acc': evh[-1], 'global/loss': lh[-1]})

progress.close()

wandb.finish()

# Persist the trained weights in the scratch directory.
torch.save(model.state_dict(), f'{temp_path}/model.pt')
# %%

data prepared


  0%|          | 0/820 [00:00<?, ?it/s]

start training
begin evm
end evm
begin evm
end evm
begin evm
end evm
begin evm
end evm
begin evm
end evm




VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
each/cmt-afm,▅▄█▁▆
each/cmt-avgp,▃▂▃▁█
each/cmt-bt,▁▅▇██
each/cmt-rec,▂█▇▂▁
each/confOf-afm,▃▁▂█▁
each/confOf-avgp,▂▁▁█▁
each/confOf-bt,▄▆▁█▂
each/confOf-rec,▅█▁▅▅
each/conference-afm,█▃█▁▆
each/conference-avgp,█▃█▁▄

0,1
each/cmt-afm,0.16441
each/cmt-avgp,0.46193
each/cmt-bt,0.85
each/cmt-rec,0.1
each/confOf-afm,0.01214
each/confOf-avgp,0.0063
each/confOf-bt,0.35
each/confOf-rec,0.16667
each/conference-afm,0.13956
each/conference-avgp,0.12771
