In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import sys
import logging

# Configure logging for Fonduer
logging.basicConfig(stream=sys.stdout, format='[%(levelname)s] %(name)s:%(lineno)s - %(message)s')
logger = logging.getLogger('fonduer')
logger.setLevel(logging.INFO)

PARALLEL = 16 # assuming a quad-core machine
ATTRIBUTE = "circular_connectors"
conn_string = 'postgresql://localhost:5432/' + ATTRIBUTE

In [2]:
# If you've run this before, set FIRST_TIME to False to save time
FIRST_TIME = False

In [3]:
from fonduer import Meta

session = Meta.init(conn_string).Session()

[INFO] fonduer.meta:86 - Connecting user:None to localhost:5432/circular_connectors
[INFO] fonduer.meta:110 - Initializing the storage schema


In [5]:
from hack.utils import parse_dataset

dirname = "."

docs, train_docs, dev_docs, test_docs = parse_dataset(
    session, dirname, first_time=FIRST_TIME, parallel=PARALLEL, max_docs=100
)
logger.info(f"# of train Documents: {len(train_docs)}")
logger.info(f"# of dev Documents: {len(dev_docs)}")
logger.info(f"# of test Documents: {len(test_docs)}")

[INFO] fonduer:8 - # of train Documents: 100
[INFO] fonduer:9 - # of dev Documents: 100
[INFO] fonduer:10 - # of test Documents: 98


In [6]:
from fonduer.parser.models import Document, Section, Paragraph, Sentence, Figure

logger.info(f"Documents: {session.query(Document).count()}")
logger.info(f"Sections: {session.query(Section).count()}")
logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
logger.info(f"Sentences: {session.query(Sentence).count()}")
logger.info(f"Figures: {session.query(Figure).count()}")

[INFO] fonduer:3 - Documents: 298
[INFO] fonduer:4 - Sections: 298
[INFO] fonduer:5 - Paragraphs: 330839
[INFO] fonduer:6 - Sentences: 341046
[INFO] fonduer:7 - Figures: 21269


In [8]:
from fonduer.candidates.models import mention_subclass

Thumbnails = mention_subclass("Thumbnails")

In [9]:
from fonduer.candidates import MentionFigures

thumbnails_img = MentionFigures()

In [52]:
from fonduer.candidates.matchers import _Matcher

In [108]:
from PIL import Image

class HasFigures(_Matcher):   
    def _f(self, m):
        file_path = ""
        for prefix in ["data/train/html/", "data/dev/html/", "data/test/html/"]:
            if os.path.exists(prefix + m.figure.url):
                file_path = prefix + m.figure.url
        if file_path == "":
            return False
        img = Image.open(file_path)
        width, height = img.size
        min_value = min(width, height)
        return min_value > 50

In [109]:
from fonduer.candidates import MentionExtractor
from fonduer.candidates.matchers import DoNothingMatcher

mention_extractor = MentionExtractor(
    session, [Thumbnails], [thumbnails_img], [HasFigures()], parallelism=PARALLEL
)

from fonduer.candidates.models import Mention

if FIRST_TIME:
    mention_extractor.apply(docs)

logger.info("Total Mentions: {}".format(session.query(Mention).count()))

[INFO] fonduer.candidates.mentions:460 - Clearing table: thumbnails
[INFO] fonduer.candidates.mentions:468 - Cascading to clear table: thumbnail_label
[INFO] fonduer.utils.udf:57 - Running UDF...


HBox(children=(IntProgress(value=0, max=298), HTML(value='')))

[INFO] fonduer:13 - Total Mentions: 8917


In [110]:
from fonduer.candidates.models import candidate_subclass

ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

In [111]:
from fonduer.candidates import CandidateExtractor

candidate_extractor = CandidateExtractor(
    session, [ThumbnailLabel], throttlers=[None], parallelism=PARALLEL
)

if FIRST_TIME or True:
    candidate_extractor.apply(train_docs, split=0)
    candidate_extractor.apply(dev_docs, split=1)
    candidate_extractor.apply(test_docs, split=2)

[INFO] fonduer.candidates.candidates:125 - Clearing table thumbnail_label (split 0)
[INFO] fonduer.utils.udf:57 - Running UDF...


HBox(children=(IntProgress(value=0), HTML(value='')))

[INFO] fonduer.candidates.candidates:125 - Clearing table thumbnail_label (split 1)
[INFO] fonduer.utils.udf:57 - Running UDF...


HBox(children=(IntProgress(value=0), HTML(value='')))

[INFO] fonduer.candidates.candidates:125 - Clearing table thumbnail_label (split 2)
[INFO] fonduer.utils.udf:57 - Running UDF...


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

In [112]:
train_cands = candidate_extractor.get_candidates(split=0)
dev_cands = candidate_extractor.get_candidates(split=1)
test_cands = candidate_extractor.get_candidates(split=2)

In [113]:
logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

[INFO] fonduer:1 - Total train candidate:	7256
[INFO] fonduer:2 - Total dev candidate:	453
[INFO] fonduer:3 - Total test candidate:	1208


In [114]:
fin = open("data/ground_truth.txt", "r")
gt = set()
for line in fin:
    gt.add("::".join(line.lower().split()))
fin.close()
# gt

In [115]:
TRUE=1
FALSE=2
ABSTAIN=0

In [116]:
def LF_gt_label(c):
    doc_file_id = f"{c[0].context.figure.document.name.lower()}.pdf::{os.path.basename(c[0].context.figure.url.lower())}"
#     print(doc_file_id)
    return TRUE if doc_file_id in gt else FALSE

In [117]:
ans = {0:0, 1:0, 2:0}

gt_dev_pb = []
gt_dev = []
gt_test = []

In [118]:
for cand in dev_cands[0]:
    if LF_gt_label(cand) == 1:
        ans[1] += 1
        gt_dev_pb.append([1., 0.])
        gt_dev.append(1.)
    else:
        ans[2] += 1
        gt_dev_pb.append([0., 1.])
        gt_dev.append(2.)

In [119]:
ans

{0: 0, 1: 69, 2: 384}

In [120]:
ans = {0:0, 1:0, 2:0}

In [121]:
for cand in test_cands[0]:
    gt_test.append(LF_gt_label(cand))
    ans[gt_test[-1]] += 1

In [122]:
ans

{0: 0, 1: 160, 2: 1048}

In [123]:
batch_size = 64
input_size = 224

In [124]:
from disc_model.torchnet import *
from utils import *

In [33]:
train_loader = torch.utils.data.DataLoader(
    ImageList(
        data=all_cands,
        label=torch.Tensor(all_label),
#         label=torch.Tensor(gt_dev_pb),
        transform=transform(input_size),
        prefix="data/dev/html/",
    ),
    batch_size=batch_size,
    shuffle = True,
#     sampler = sampler
)

In [126]:
train_loader = torch.utils.data.DataLoader(
    ImageList(
        data=dev_cands[0],
#         label=torch.Tensor(gt_dev),
        label=torch.Tensor(gt_dev_pb),
        transform=transform(input_size),
        prefix="data/dev/html/",
    ),
    batch_size=batch_size,
    shuffle = False,
#     sampler = sampler
)

In [127]:
dev_loader = torch.utils.data.DataLoader(
    ImageList(
        data=dev_cands[0],
#         label=torch.Tensor(gt_dev),
        label=gt_dev,
        transform=transform(input_size),
        prefix="data/dev/html/",
    ),
    batch_size=batch_size,
    shuffle=False,
)

In [128]:
test_loader = torch.utils.data.DataLoader(
    ImageList(
        data=test_cands[0],
        label=gt_test,
        transform=transform(input_size),
        prefix="data/test/html/",
    ),
    batch_size=batch_size,
    shuffle=False,
)

In [101]:
from metal import EndModel

In [102]:
em_config = {
    # GENERAL
    "seed": None,
    "verbose": True,
    "show_plots": True,
    # Network
    # The first value is the output dim of the input module (or the sum of
    # the output dims of all the input modules if multitask=True and
    # multiple input modules are provided). The last value is the
    # output dim of the head layer (i.e., the cardinality of the
    # classification task). The remaining values are the output dims of
    # middle layers (if any). The number of middle layers will be inferred
    # from this list.
    #     "layer_out_dims": [10, 2],
    # Input layer configs
    "input_layer_config": {
        "input_relu": False,
        "input_batchnorm": False,
        "input_dropout": 0.0,
    },
    # Middle layer configs
    "middle_layer_config": {
        "middle_relu": False,
        "middle_batchnorm": False,
        "middle_dropout": 0.0,
    },
    # Can optionally skip the head layer completely, for e.g. running baseline
    # models...
    "skip_head": True,
    # GPU
    "use_cuda": True,
    # MODEL CLASS
    "resnet18"
    # DATA CONFIG
    "src": "gm",
    # TRAINING
    "train_config": {
        # Display
        "print_every": 1,  # Print after this many epochs
        "disable_prog_bar": False,  # Disable progress bar each epoch
        # Dataloader
        "data_loader_config": {"batch_size": 32, "num_workers": 8, "sampler": None},
        # Loss weights
        "loss_weights": [0.5, 0.5],
        # Train Loop
        "n_epochs": 20,
        # 'grad_clip': 0.0,
        "l2": 0.0,
        # "lr": 0.01,
        "validation_metric": "accuracy",
        "validation_freq": 1,
        # Evaluate dev for during training every this many epochs
        # Optimizer
        "optimizer_config": {
            "optimizer": "adam",
            "optimizer_common": {"lr": 0.01},
            # Optimizer - SGD
            "sgd_config": {"momentum": 0.9},
            # Optimizer - Adam
            "adam_config": {"betas": (0.9, 0.999)},
        },
        # Scheduler
        "scheduler_config": {
            "scheduler": "reduce_on_plateau",
            # ['constant', 'exponential', 'reduce_on_plateu']
            # Freeze learning rate initially this many epochs
            "lr_freeze": 0,
            # Scheduler - exponential
            "exponential_config": {"gamma": 0.9},  # decay rate
            # Scheduler - reduce_on_plateau
            "plateau_config": {
                "factor": 0.5,
                "patience": 1,
                "threshold": 0.0001,
                "min_lr": 1e-5,
            },
        },
        # Checkpointer
        "checkpoint": True,
        "checkpoint_config": {
            "checkpoint_min": -1,
            # The initial best score to beat to merit checkpointing
            "checkpoint_runway": 0,
            # Don't start taking checkpoints until after this many epochs
        },
    },
}

In [197]:
from metal.tuners import RandomSearchTuner
from metal.contrib.logging.tensorboard import TensorBoardWriter

log_config = {"log_dir": "./run_logs", "run_name": "image"}

tuner_config = {"max_search": 1}
search_space = {
    "l2": [0.001, 0.0001, 0.00001],  # linear range
    "lr": {"range": [0.0001, 0.1], "scale": "log"},  # log range
}


train_config = em_config["train_config"]


# Defining network parameters
num_classes = 2
fc_size = 2
hidden_size = 2
pretrained = True

# Set CUDA device
torch.cuda.set_device(1)
# os.environ['CUDA_VISIBLE_DEVICES']='0'

# Initializing input module
input_module = get_cnn("resnet18", pretrained=pretrained, num_classes=num_classes)


# Initializing model object
init_args = [[num_classes]]
init_kwargs = {"input_module": input_module}
init_kwargs.update(em_config)

# init_kwargs.update(em_config)
# max_search = tuner_config['max_search']
# metric = train_config['validation_metric']

# Training model as a single pass
# if args.single_pass:
# end_model = EndModel(
#    [hidden_size, fc_size, num_classes],
#    input_module=input_module,use_cuda='True'
#     **em_config
# )
# end_model.train_model(
#    train_data=train_loader,
#    dev_data=dev_loader,
#     **train_config
# )

# Searching model
# else:
searcher = RandomSearchTuner(EndModel, **log_config)

end_model = searcher.search(
    search_space,
    dev_loader,
    train_args=[train_loader],
    init_args=init_args,
    init_kwargs=init_kwargs,
    train_kwargs=train_config,
    max_search=tuner_config["max_search"],
)

# Evaluating model
scores = end_model.score(
    test_loader, metric=["accuracy", "precision", "recall", "f1", "roc-auc"]
)

labels, _, probs = end_model._get_predictions(test_loader, return_probs=True)



  0%|          | 0/8 [00:00<?, ?it/s][A[A

Using class weight vector [0.5, 0.5]...

Network architecture:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_ru



  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.405][A[A

 12%|█▎        | 1/8 [00:00<00:03,  1.96it/s, avg_loss=0.405][A[A

 12%|█▎        | 1/8 [00:01<00:03,  1.96it/s, avg_loss=0.388][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.85it/s, avg_loss=0.388][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.85it/s, avg_loss=0.355][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.79it/s, avg_loss=0.355][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.79it/s, avg_loss=0.33] [A[A

 50%|█████     | 4/8 [00:02<00:02,  1.48it/s, avg_loss=0.33][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.48it/s, avg_loss=0.295][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.51it/s, avg_loss=0.295][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.51it/s, avg_loss=0.276][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.58it/s, avg_loss=0.276][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.58it/s, avg_loss=0.262][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.62it/s, avg_loss=0.262][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.62i

Saving model at iteration 0 with best score 0.848
[E:0]	Train Loss: 0.261	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=2.08][A[A

 12%|█▎        | 1/8 [00:00<00:03,  1.86it/s, avg_loss=2.08][A[A

 12%|█▎        | 1/8 [00:01<00:03,  1.86it/s, avg_loss=1.14][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.80it/s, avg_loss=1.14][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.80it/s, avg_loss=0.879][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.78it/s, avg_loss=0.879][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.78it/s, avg_loss=0.709][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.57it/s, avg_loss=0.709][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.57it/s, avg_loss=0.644][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.58it/s, avg_loss=0.644][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.58it/s, avg_loss=0.584][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.63it/s, avg_loss=0.584][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.63it/s, avg_loss=0.831][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.64it/s, avg_loss=0.831][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.64it/s

[E:1]	Train Loss: 0.826	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.313][A[A

 12%|█▎        | 1/8 [00:00<00:03,  2.07it/s, avg_loss=0.313][A[A

 12%|█▎        | 1/8 [00:01<00:03,  2.07it/s, avg_loss=0.271][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.99it/s, avg_loss=0.271][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.99it/s, avg_loss=0.252][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.89it/s, avg_loss=0.252][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.89it/s, avg_loss=0.239][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.62it/s, avg_loss=0.239][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.62it/s, avg_loss=0.23] [A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.60it/s, avg_loss=0.23][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.60it/s, avg_loss=0.223][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.64it/s, avg_loss=0.223][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.64it/s, avg_loss=0.22] [A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.58it/s, avg_loss=0.22][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.58it

[E:2]	Train Loss: 0.221	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.26][A[A

 12%|█▎        | 1/8 [00:00<00:03,  1.83it/s, avg_loss=0.26][A[A

 12%|█▎        | 1/8 [00:01<00:03,  1.83it/s, avg_loss=0.236][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.77it/s, avg_loss=0.236][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.77it/s, avg_loss=0.229][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.69it/s, avg_loss=0.229][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.69it/s, avg_loss=0.222][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.51it/s, avg_loss=0.222][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.51it/s, avg_loss=0.213][A[A

 62%|██████▎   | 5/8 [00:03<00:02,  1.43it/s, avg_loss=0.213][A[A

 62%|██████▎   | 5/8 [00:03<00:02,  1.43it/s, avg_loss=0.209][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.51it/s, avg_loss=0.209][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.51it/s, avg_loss=0.208][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.58it/s, avg_loss=0.208][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.58it

[E:3]	Train Loss: 0.209	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.283][A[A

 12%|█▎        | 1/8 [00:00<00:03,  2.00it/s, avg_loss=0.283][A[A

 12%|█▎        | 1/8 [00:01<00:03,  2.00it/s, avg_loss=0.25] [A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.70it/s, avg_loss=0.25][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.70it/s, avg_loss=0.238][A[A

 38%|███▊      | 3/8 [00:01<00:03,  1.65it/s, avg_loss=0.238][A[A

 38%|███▊      | 3/8 [00:02<00:03,  1.65it/s, avg_loss=0.228][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.49it/s, avg_loss=0.228][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.49it/s, avg_loss=0.217][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.52it/s, avg_loss=0.217][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.52it/s, avg_loss=0.212][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.59it/s, avg_loss=0.212][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.59it/s, avg_loss=0.211][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.65it/s, avg_loss=0.211][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.65i

[E:4]	Train Loss: 0.211	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.262][A[A

 12%|█▎        | 1/8 [00:00<00:03,  2.09it/s, avg_loss=0.262][A[A

 12%|█▎        | 1/8 [00:01<00:03,  2.09it/s, avg_loss=0.237][A[A

 25%|██▌       | 2/8 [00:01<00:02,  2.01it/s, avg_loss=0.237][A[A

 25%|██▌       | 2/8 [00:01<00:02,  2.01it/s, avg_loss=0.229][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.93it/s, avg_loss=0.229][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.93it/s, avg_loss=0.22] [A[A

 50%|█████     | 4/8 [00:02<00:02,  1.65it/s, avg_loss=0.22][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.65it/s, avg_loss=0.211][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.62it/s, avg_loss=0.211][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.62it/s, avg_loss=0.207][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.66it/s, avg_loss=0.207][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.66it/s, avg_loss=0.206][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.71it/s, avg_loss=0.206][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.71i

[E:5]	Train Loss: 0.206	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.256][A[A

 12%|█▎        | 1/8 [00:00<00:03,  2.08it/s, avg_loss=0.256][A[A

 12%|█▎        | 1/8 [00:01<00:03,  2.08it/s, avg_loss=0.232][A[A

 25%|██▌       | 2/8 [00:01<00:02,  2.00it/s, avg_loss=0.232][A[A

 25%|██▌       | 2/8 [00:01<00:02,  2.00it/s, avg_loss=0.224][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.92it/s, avg_loss=0.224][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.92it/s, avg_loss=0.217][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.65it/s, avg_loss=0.217][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.65it/s, avg_loss=0.206][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.63it/s, avg_loss=0.206][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.63it/s, avg_loss=0.202][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.67it/s, avg_loss=0.202][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.67it/s, avg_loss=0.2]  [A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.71it/s, avg_loss=0.2][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.71it

[E:6]	Train Loss: 0.201	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.25][A[A

 12%|█▎        | 1/8 [00:00<00:03,  1.98it/s, avg_loss=0.25][A[A

 12%|█▎        | 1/8 [00:01<00:03,  1.98it/s, avg_loss=0.227][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.89it/s, avg_loss=0.227][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.89it/s, avg_loss=0.218][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.84it/s, avg_loss=0.218][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.84it/s, avg_loss=0.211][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.60it/s, avg_loss=0.211][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.60it/s, avg_loss=0.2]  [A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.59it/s, avg_loss=0.2][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.59it/s, avg_loss=0.196][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.63it/s, avg_loss=0.196][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.63it/s, avg_loss=0.195][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.68it/s, avg_loss=0.195][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.68it/s

[E:7]	Train Loss: 0.195	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.241][A[A

 12%|█▎        | 1/8 [00:00<00:03,  2.09it/s, avg_loss=0.241][A[A

 12%|█▎        | 1/8 [00:01<00:03,  2.09it/s, avg_loss=0.218][A[A

 25%|██▌       | 2/8 [00:01<00:02,  2.00it/s, avg_loss=0.218][A[A

 25%|██▌       | 2/8 [00:01<00:02,  2.00it/s, avg_loss=0.209][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.84it/s, avg_loss=0.209][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.84it/s, avg_loss=0.203][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.60it/s, avg_loss=0.203][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.60it/s, avg_loss=0.191][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.60it/s, avg_loss=0.191][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.60it/s, avg_loss=0.187][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.64it/s, avg_loss=0.187][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.64it/s, avg_loss=0.185][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.69it/s, avg_loss=0.185][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.69

[E:8]	Train Loss: 0.185	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.23][A[A

 12%|█▎        | 1/8 [00:00<00:04,  1.70it/s, avg_loss=0.23][A[A

 12%|█▎        | 1/8 [00:01<00:04,  1.70it/s, avg_loss=0.207][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.71it/s, avg_loss=0.207][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.71it/s, avg_loss=0.198][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.69it/s, avg_loss=0.198][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.69it/s, avg_loss=0.191][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.50it/s, avg_loss=0.191][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.50it/s, avg_loss=0.181][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.51it/s, avg_loss=0.181][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.51it/s, avg_loss=0.176][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.56it/s, avg_loss=0.176][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.56it/s, avg_loss=0.174][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.61it/s, avg_loss=0.174][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.61it

[E:9]	Train Loss: 0.173	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.232][A[A

 12%|█▎        | 1/8 [00:00<00:03,  1.98it/s, avg_loss=0.232][A[A

 12%|█▎        | 1/8 [00:01<00:03,  1.98it/s, avg_loss=0.206][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.88it/s, avg_loss=0.206][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.88it/s, avg_loss=0.195][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.81it/s, avg_loss=0.195][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.81it/s, avg_loss=0.188][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.57it/s, avg_loss=0.188][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.57it/s, avg_loss=0.176][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.56it/s, avg_loss=0.176][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.56it/s, avg_loss=0.171][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.60it/s, avg_loss=0.171][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.60it/s, avg_loss=0.168][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.62it/s, avg_loss=0.168][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.62

[E:10]	Train Loss: 0.168	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.214][A[A

 12%|█▎        | 1/8 [00:00<00:03,  2.03it/s, avg_loss=0.214][A[A

 12%|█▎        | 1/8 [00:01<00:03,  2.03it/s, avg_loss=0.193][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.97it/s, avg_loss=0.193][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.97it/s, avg_loss=0.184][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.89it/s, avg_loss=0.184][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.89it/s, avg_loss=0.179][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.44it/s, avg_loss=0.179][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.44it/s, avg_loss=0.167][A[A

 62%|██████▎   | 5/8 [00:03<00:02,  1.45it/s, avg_loss=0.167][A[A

 62%|██████▎   | 5/8 [00:03<00:02,  1.45it/s, avg_loss=0.162][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.53it/s, avg_loss=0.162][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.53it/s, avg_loss=0.16] [A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.60it/s, avg_loss=0.16][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.60i

[E:11]	Train Loss: 0.160	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.202][A[A

 12%|█▎        | 1/8 [00:00<00:03,  2.09it/s, avg_loss=0.202][A[A

 12%|█▎        | 1/8 [00:01<00:03,  2.09it/s, avg_loss=0.184][A[A

 25%|██▌       | 2/8 [00:01<00:03,  2.00it/s, avg_loss=0.184][A[A

 25%|██▌       | 2/8 [00:01<00:03,  2.00it/s, avg_loss=0.176][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.91it/s, avg_loss=0.176][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.91it/s, avg_loss=0.172][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.63it/s, avg_loss=0.172][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.63it/s, avg_loss=0.16] [A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.62it/s, avg_loss=0.16][A[A

 62%|██████▎   | 5/8 [00:03<00:01,  1.62it/s, avg_loss=0.155][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.65it/s, avg_loss=0.155][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.65it/s, avg_loss=0.153][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.45it/s, avg_loss=0.153][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.45i

[E:12]	Train Loss: 0.153	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.192][A[A

 12%|█▎        | 1/8 [00:00<00:03,  1.95it/s, avg_loss=0.192][A[A

 12%|█▎        | 1/8 [00:01<00:03,  1.95it/s, avg_loss=0.176][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.91it/s, avg_loss=0.176][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.91it/s, avg_loss=0.168][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.86it/s, avg_loss=0.168][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.86it/s, avg_loss=0.164][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.48it/s, avg_loss=0.164][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.48it/s, avg_loss=0.153][A[A

 62%|██████▎   | 5/8 [00:03<00:02,  1.46it/s, avg_loss=0.153][A[A

 62%|██████▎   | 5/8 [00:03<00:02,  1.46it/s, avg_loss=0.148][A[A

 75%|███████▌  | 6/8 [00:03<00:01,  1.52it/s, avg_loss=0.148][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.52it/s, avg_loss=0.147][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.54it/s, avg_loss=0.147][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.54

[E:13]	Train Loss: 0.146	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.183][A[A

 12%|█▎        | 1/8 [00:00<00:03,  2.00it/s, avg_loss=0.183][A[A

 12%|█▎        | 1/8 [00:01<00:03,  2.00it/s, avg_loss=0.169][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.75it/s, avg_loss=0.169][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.75it/s, avg_loss=0.161][A[A

 38%|███▊      | 3/8 [00:01<00:02,  1.68it/s, avg_loss=0.161][A[A

 38%|███▊      | 3/8 [00:02<00:02,  1.68it/s, avg_loss=0.158][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.47it/s, avg_loss=0.158][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.47it/s, avg_loss=0.147][A[A

 62%|██████▎   | 5/8 [00:03<00:02,  1.47it/s, avg_loss=0.147][A[A

 62%|██████▎   | 5/8 [00:04<00:02,  1.47it/s, avg_loss=0.142][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.51it/s, avg_loss=0.142][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.51it/s, avg_loss=0.141][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.56it/s, avg_loss=0.141][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.56

[E:14]	Train Loss: 0.140	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.175][A[A

 12%|█▎        | 1/8 [00:00<00:03,  1.88it/s, avg_loss=0.175][A[A

 12%|█▎        | 1/8 [00:01<00:03,  1.88it/s, avg_loss=0.162][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.84it/s, avg_loss=0.162][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.84it/s, avg_loss=0.155][A[A

 38%|███▊      | 3/8 [00:01<00:03,  1.64it/s, avg_loss=0.155][A[A

 38%|███▊      | 3/8 [00:02<00:03,  1.64it/s, avg_loss=0.152][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.43it/s, avg_loss=0.152][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.43it/s, avg_loss=0.141][A[A

 62%|██████▎   | 5/8 [00:03<00:02,  1.42it/s, avg_loss=0.141][A[A

 62%|██████▎   | 5/8 [00:04<00:02,  1.42it/s, avg_loss=0.137][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.46it/s, avg_loss=0.137][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.46it/s, avg_loss=0.136][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.54it/s, avg_loss=0.136][A[A

 88%|████████▊ | 7/8 [00:04<00:00,  1.54

[E:15]	Train Loss: 0.135	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.169][A[A

 12%|█▎        | 1/8 [00:00<00:06,  1.06it/s, avg_loss=0.169][A[A

 12%|█▎        | 1/8 [00:01<00:06,  1.06it/s, avg_loss=0.157][A[A

 25%|██▌       | 2/8 [00:01<00:05,  1.14it/s, avg_loss=0.157][A[A

 25%|██▌       | 2/8 [00:02<00:05,  1.14it/s, avg_loss=0.15] [A[A

 38%|███▊      | 3/8 [00:02<00:04,  1.24it/s, avg_loss=0.15][A[A

 38%|███▊      | 3/8 [00:03<00:04,  1.24it/s, avg_loss=0.147][A[A

 50%|█████     | 4/8 [00:03<00:03,  1.11it/s, avg_loss=0.147][A[A

 50%|█████     | 4/8 [00:04<00:03,  1.11it/s, avg_loss=0.136][A[A

 62%|██████▎   | 5/8 [00:04<00:02,  1.20it/s, avg_loss=0.136][A[A

 62%|██████▎   | 5/8 [00:04<00:02,  1.20it/s, avg_loss=0.132][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.29it/s, avg_loss=0.132][A[A

 75%|███████▌  | 6/8 [00:05<00:01,  1.29it/s, avg_loss=0.132][A[A

 88%|████████▊ | 7/8 [00:05<00:00,  1.37it/s, avg_loss=0.132][A[A

 88%|████████▊ | 7/8 [00:05<00:00,  1.37i

[E:16]	Train Loss: 0.131	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.162][A[A

 12%|█▎        | 1/8 [00:00<00:03,  2.06it/s, avg_loss=0.162][A[A

 12%|█▎        | 1/8 [00:01<00:03,  2.06it/s, avg_loss=0.151][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.89it/s, avg_loss=0.151][A[A

 25%|██▌       | 2/8 [00:01<00:03,  1.89it/s, avg_loss=0.145][A[A

 38%|███▊      | 3/8 [00:01<00:03,  1.65it/s, avg_loss=0.145][A[A

 38%|███▊      | 3/8 [00:02<00:03,  1.65it/s, avg_loss=0.143][A[A

 50%|█████     | 4/8 [00:02<00:02,  1.41it/s, avg_loss=0.143][A[A

 50%|█████     | 4/8 [00:03<00:02,  1.41it/s, avg_loss=0.132][A[A

 62%|██████▎   | 5/8 [00:03<00:02,  1.30it/s, avg_loss=0.132][A[A

 62%|██████▎   | 5/8 [00:04<00:02,  1.30it/s, avg_loss=0.128][A[A

 75%|███████▌  | 6/8 [00:04<00:01,  1.29it/s, avg_loss=0.128][A[A

 75%|███████▌  | 6/8 [00:05<00:01,  1.29it/s, avg_loss=0.128][A[A

 88%|████████▊ | 7/8 [00:05<00:00,  1.33it/s, avg_loss=0.128][A[A

 88%|████████▊ | 7/8 [00:05<00:00,  1.33

[E:17]	Train Loss: 0.127	Dev score: 0.848




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.159][A[A

 12%|█▎        | 1/8 [00:00<00:05,  1.25it/s, avg_loss=0.159][A[A

 12%|█▎        | 1/8 [00:01<00:05,  1.25it/s, avg_loss=0.149][A[A

 25%|██▌       | 2/8 [00:01<00:04,  1.24it/s, avg_loss=0.149][A[A

 25%|██▌       | 2/8 [00:02<00:04,  1.24it/s, avg_loss=0.143][A[A

 38%|███▊      | 3/8 [00:02<00:04,  1.24it/s, avg_loss=0.143][A[A

 38%|███▊      | 3/8 [00:03<00:04,  1.24it/s, avg_loss=0.14] [A[A

 50%|█████     | 4/8 [00:03<00:03,  1.14it/s, avg_loss=0.14][A[A

 50%|█████     | 4/8 [00:04<00:03,  1.14it/s, avg_loss=0.13][A[A

 62%|██████▎   | 5/8 [00:04<00:02,  1.16it/s, avg_loss=0.13][A[A

 62%|██████▎   | 5/8 [00:05<00:02,  1.16it/s, avg_loss=0.126][A[A

 75%|███████▌  | 6/8 [00:05<00:01,  1.22it/s, avg_loss=0.126][A[A

 75%|███████▌  | 6/8 [00:05<00:01,  1.22it/s, avg_loss=0.126][A[A

 88%|████████▊ | 7/8 [00:05<00:00,  1.28it/s, avg_loss=0.126][A[A

 88%|████████▊ | 7/8 [00:05<00:00,  1.28it/

Saving model at iteration 18 with best score 0.850
[E:18]	Train Loss: 0.125	Dev score: 0.850




  0%|          | 0/8 [00:00<?, ?it/s, avg_loss=0.155][A[A

 12%|█▎        | 1/8 [00:00<00:05,  1.35it/s, avg_loss=0.155][A[A

 12%|█▎        | 1/8 [00:01<00:05,  1.35it/s, avg_loss=0.145][A[A

 25%|██▌       | 2/8 [00:01<00:04,  1.29it/s, avg_loss=0.145][A[A

 25%|██▌       | 2/8 [00:02<00:04,  1.29it/s, avg_loss=0.14] [A[A

 38%|███▊      | 3/8 [00:02<00:04,  1.19it/s, avg_loss=0.14][A[A

 38%|███▊      | 3/8 [00:03<00:04,  1.19it/s, avg_loss=0.138][A[A

 50%|█████     | 4/8 [00:03<00:03,  1.03it/s, avg_loss=0.138][A[A

 50%|█████     | 4/8 [00:04<00:03,  1.03it/s, avg_loss=0.127][A[A

 62%|██████▎   | 5/8 [00:04<00:02,  1.01it/s, avg_loss=0.127][A[A

 62%|██████▎   | 5/8 [00:05<00:02,  1.01it/s, avg_loss=0.124][A[A

 75%|███████▌  | 6/8 [00:05<00:01,  1.04it/s, avg_loss=0.124][A[A

 75%|███████▌  | 6/8 [00:06<00:01,  1.04it/s, avg_loss=0.124][A[A

 88%|████████▊ | 7/8 [00:06<00:00,  1.07it/s, avg_loss=0.124][A[A

 88%|████████▊ | 7/8 [00:06<00:00,  1.07i

Saving model at iteration 19 with best score 0.865
[E:19]	Train Loss: 0.123	Dev score: 0.865
Restoring best model from iteration 19 with score 0.865
Finished Training
Accuracy: 0.865
        y=1    y=2   
 l=1     9      1    
 l=2    60     383   
[SUMMARY]
Best model: [0]
Best config: {'l2': 0.001, 'lr': 0.020702566920833793}
Best score: 0.8653421633554084


  max_idxs = np.where(diffs[i, :] < TOL)[0]


ValueError: 'a' cannot be empty unless no samples are taken

```
============================================================
[SUMMARY]
Best model: [0]
Best config: {'l2': 0.0001, 'lr': 0.0010025532524850966}
Best score: 0.9955849889624724
============================================================
Accuracy: 0.926
Precision: 0.675
Recall: 0.856
F1: 0.755
Roc-auc: 0.964
        y=1    y=2   
 l=1    137    66    
 l=2    23     982   
```

In [190]:
pred_dict = {}

for c, y in zip(test_cands[0], probs[:,0]):
    doc_file_id = f"{c[0].context.figure.document.name.lower()}.pdf::{os.path.basename(c[0].context.figure.url.lower())}"
    pred_dict[doc_file_id] = y

In [191]:
all_test_fig_id = set()

for doc in test_docs:
    for fig in doc.figures:
        doc_file_id = f"{doc.name.lower()}.pdf::{os.path.basename(fig.url.lower())}"
        all_test_fig_id.add(doc_file_id)

In [196]:
b = 0.7

tp = 0
fp = 0
fn = 0

for id in all_test_fig_id:
    if id in gt:
        p = True
    else: p = False
    if id in pred_dict and pred_dict[id] >= b:
        t = True
    else: t = False
    
    if t and p: tp += 1
    if t and not p: fp += 1
    if not t and p: fn += 1

prec = tp / (tp + fp)
rec = tp / (tp + fn)
f1 = 2 * prec * rec / (prec + rec)

tp, fp, fn, prec, rec, f1

(109, 72, 53, 0.6022099447513812, 0.6728395061728395, 0.6355685131195336)