In [1]:
from torch.utils.data import Dataset
import torch
import torch.backends.cudnn as cudnn
from utils.preprocessing import WordEmbedding, load_word_emb
import numpy as np
import os
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from utils.metrics import accuracy, precision, recall
from models.pointer_net import PointerNet
from utils.datasets import SchemaMatchingDataset

import warnings
warnings.filterwarnings("ignore")  

In [2]:
# Hyperparameters and runtime switches; the network entries are passed to
# PointerNet and 'gpu' gates CUDA use in the evaluation cell.
params = dict(
    # --- data ---
    batch_size=64,
    trainsplit=0.8,
    shuffle=True,
    # --- training ---
    nof_epoch=150,
    lr=0.001,
    # --- hardware ---
    gpu=True,
    # --- network architecture ---
    input_size=300,
    embedding_size=300,
    hiddens=256,
    nof_lstms=2,
    dropout=0.3,
    bidir=True,
)

In [3]:
# Build the schema-matching dataset and populate it from disk.
# The constructor receives no data (None) together with from_path=True —
# presumably meaning the contents come from the subsequent .load() call;
# verify against utils.datasets.SchemaMatchingDataset.
# `dataset` is consumed by the bootstrap-evaluation cell via yield_bootstrap().
dataset = SchemaMatchingDataset(None, from_path=True)
dataset.load('data/training')

In [4]:
# Bootstrap evaluation of three serialized schema-pointer checkpoints.
#   np = no pretraining
#   ap = pretraining on alphabet sorting
#   sp = pretraining on 1to1 schema pointing
# For each version the checkpoint is loaded and evaluated (no training) on
# bootstrap batches drawn from `dataset`. Per-batch accuracy/recall/precision
# are written to logging/bootstrap_analysis_<version>.txt.

# Single source of truth for device placement: honour params['gpu'] AND actual
# CUDA availability (note the call parentheses on is_available()).
use_cuda = params['gpu'] and torch.cuda.is_available()
if use_cuda:
    cudnn.benchmark = True

num_samples = 20000
batch_size = params['batch_size']

# to_csv does not create missing directories.
os.makedirs('logging', exist_ok=True)

for version in ['np', 'ap', 'sp']:
    logs = []
    model = PointerNet(params['input_size'],
                       params['embedding_size'],
                       params['hiddens'],
                       params['nof_lstms'],
                       params['dropout'],
                       params['bidir'])

    model.initialize('serialized/schema_pointer_{}.pt'.format(version))

    if use_cuda:
        model.cuda()
    # Evaluation only: switch off dropout so metrics reflect the trained model,
    # not a randomly thinned one.
    model.eval()

    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        batches = dataset.yield_bootstrap(num_samples, batch_size)
        for inputs, targets in tqdm(batches, total=num_samples, desc=version):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            _outputs, pointers = model(inputs)
            logs.append({
                'accuracy': accuracy(pointers, targets),
                'recall': recall(pointers, targets),
                'precision': precision(pointers, targets),
            })

    logs = pd.DataFrame(logs)
    logs.to_csv('logging/bootstrap_analysis_{}.txt'.format(version), index=False)

  1%|▌                                                                               | 37/5000 [00:06<13:28,  6.14it/s]


KeyboardInterrupt: 