# In [1]:
import os
import time
import numpy as np
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as fn

from tensorboardX import SummaryWriter
from torch_geometric.data import DataLoader

from data.dataset3 import SkeletonDataset
from models.net import DualGraphEncoder
from optimizer import get_std_opt
from utils import make_checkpoint, load_checkpoint
from tqdm import tqdm, trange
from args import make_args

# In [None]:
def run_epoch(data_loader,
              model,
              loss_compute,
              device,
              args,
              is_train=True,
              desc=None,
              num_literals=None,
              num_clauses=None):
    """Run the model once over every batch of ``data_loader`` and log the result.

    Args:
        data_loader: iterable of batches that also supports ``len()`` and
            exposes a ``.dataset`` attribute. Every batch must provide
            ``.to(device)``, ``edge_index_pos``, ``edge_index_neg`` and ``xc``.
        model: callable invoked as ``model(batch, args)``.
        loss_compute: callable invoked as ``loss_compute(xv, adj_pos, adj_neg,
            n_clause_nodes, clause_graph_idx, is_train)`` and returning
            ``(loss, sm)``.
            NOTE(review): presumably this also performs the backward pass and
            optimizer step when ``is_train`` is true -- confirm against its
            definition; nothing in this function steps an optimizer.
        device: target passed to ``.to(...)`` for the batch and index tensors.
        args: object with a ``batch_size`` attribute (read as an attribute,
            so not a plain ``dict``); also forwarded to ``model``.
        is_train: enables gradient tracking for the forward/loss computation
            and selects the label used in the summary line.
        desc: label shown on the progress bar.
        num_literals: accepted for backward compatibility; it is not needed
            to compute anything this function returns and is ignored.
        num_clauses: integer tensor holding one clause count per graph.
            It is sliced by ``batch_index * batch_size``, which assumes the
            loader yields graphs in the same order as this tensor
            (NOTE(review): confirm the loader is not shuffled).

    Returns:
        ``(total_loss, sat_r)``: the sum of the per-batch losses (detached
        from autograd) and a list holding the sat rate in percent, which is
        measured on the first batch only.

    Raises:
        TypeError: if ``num_clauses`` is not provided.
    """
    if num_clauses is None:
        raise TypeError("run_epoch() requires num_clauses "
                        "(a tensor with one clause count per graph)")

    sat_r = []
    total_loss = 0
    start = time.time()
    bs = args.batch_size
    for i, batch in tqdm(enumerate(data_loader),
                         total=len(data_loader),
                         desc=desc):
        batch = batch.to(device)

        # Clause counts of the graphs that make up this batch.
        num_cls = num_clauses[i * bs: (i + 1) * bs].reshape(-1).long()
        # Graph index of every clause node: graph g is repeated num_cls[g]
        # times, e.g. counts [2, 1] -> [0, 0, 1].
        graph_ids = torch.arange(num_cls.size(0), device=num_cls.device)
        gr_idx_cls = torch.repeat_interleave(graph_ids, num_cls).to(device)

        with torch.set_grad_enabled(is_train):
            adj_pos, adj_neg = batch.edge_index_pos, batch.edge_index_neg
            xv = model(batch, args)
            n_clause_nodes = batch.xc.size(0)
            loss, sm = loss_compute(xv, adj_pos, adj_neg, n_clause_nodes,
                                    gr_idx_cls[:n_clause_nodes], is_train)

        # Accumulate a detached value: summing a loss that is still attached
        # to autograd would keep every batch's graph alive for the whole epoch.
        total_loss += loss.detach() if torch.is_tensor(loss) else loss

        # The sat rate is only measured (and printed) for the first batch.
        if i == 0:
            sat = 100 * (sm // 0.50001).mean().item()
            sat_r.append(sat)
            print("Sat Rate: ", sat, "%")

    elapsed = time.time() - start
    # Guard the per-sample averages against an empty dataset.
    num_samples = max(len(data_loader.dataset), 1)
    ms = 'average loss' if is_train else 'accuracy '
    print(ms + ': {}; average time: {}'.format(total_loss / num_samples,
                                               elapsed / num_samples))

    return total_loss, sat_r