In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import torch
from sup_mmd.data import MMD_Dataset
import sys, json, re, os, glob, shutil
from commons.utils import get_logger
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from sup_mmd.model import LinearModel1, mmd_loss_pq, LinearModelComp1, mmd_loss_pq_comp
from sup_mmd.model import MMD, MMD_comp
from submodular.maximize import greedy as greedy_maximize
import pandas as pd
from copy import copy
from sup_mmd.functions import softmax, nz_median_dist, combine_kernels
import multiprocessing as mp
import pandas as pd

logger = get_logger("Infer")
GPU_MODE = False
import networkx as nx

import warnings
warnings.filterwarnings('ignore')

In [3]:
pattern_generic = re.compile(r'(mmdpq?)(_)(.?)(duc03|duc04|tac08|tac09)-([AB])_([xc])\.(\d+|x)_g(\d\.?\d*)_b(\d\.?\d*)_a(\d\.?\d*)(_)SF(b|x)(k|x)(c|x)') 
pattern_update = re.compile(r'(mmdpq?)-comp([01])\.(lin1|lin2)_(tac08|tac09)-([AB])_([xc])\.(\d+|x)_g(\d\.?\d*)_b(\d\.?\d*)_a(\d+)_l(\d\.?\d*)_SF(b|x)(k|x)(c|x)')   

TARGET_NAME = "y_hm_0.4"
BUDGET = 125 ##aftre compression, it will be less
#### ROUGE eval truncates, so > 100 words will be truncated

CACHE_ROOT = "./data/"

In [7]:
def infer(model_path, r, group):
    model_file = model_path.split("/")[-1]
    s = pattern_generic.search(model_file)
    generic = True ## MMD() or MMD() - lambda*MMD()
    if not s:
        s = pattern_update.search(model_file)
        if not s:
            logger.error("pattern not matched with both generic/update regexes, quitting " + model_file)
            return
        generic = False

    name = s.group(0)
    loss_name = s.group(1)

    assert loss_name in {"mmdpq"}
    train_dataset = s.group(4).lower()
    assert train_dataset in {"duc03", "duc04", "tac08", "tac09"}
    set_ = s.group(5)
    assert set_ in {"A", "B"}
    if not generic or set_ == "B":
        assert train_dataset in {"tac08", "tac09"}
    compress = s.group(6) == "c"
    split_seq = s.group(7)
    
    if split_seq != "x":
        logger.warning("please supply retrained model")

    gamma1 = float(s.group(8))
    beta = float(s.group(9))
    alpha_seq = s.group(10)
    lambdaa = 0.0

    lambdaa, diff, model_name = s.group(11), s.group(2), s.group(3)
    
    if not generic:
        assert set_ == "B"
        lambdaa = float(s.group(11))
        diff = int(s.group(2))
        model_name = s.group(3)
    
    BOOST_FIRST = s.group(12) == "b"
    KEYWORDS = s.group(13) == "k"
    comp_feats = s.group(14) =="c"
    
    logger.debug((name, list(s.groups()), (loss_name, diff, model_name, train_dataset, set_, compress, split_seq, gamma1, beta, alpha_seq, lambdaa, BOOST_FIRST, KEYWORDS, comp_feats) ) )
    
    dataset = {
        "duc03": "duc04", 
        "duc04": "duc03",
        "tac08": "tac09",
        "tac09": "tac08"
    }[train_dataset]

    dataset_name = "{}_{}".format(dataset, TARGET_NAME)

    logger.debug("loading data from " + dataset_name)
    data = MMD_Dataset.load(dataset_name, CACHE_ROOT, compress = compress)
    SURF_IDXS = data.surf_idxs(keywords = KEYWORDS, boost_first = BOOST_FIRST, comp = ( comp_feats and set_ == "B" ) )
    
    logger.info("surf feats: {}".format(
        ",".join( np.array(data.surf_names)[SURF_IDXS] )
    ))

    logger.debug("loading model from " + model_path)
    try:
        model, alpha, train_idxs, val_idxs, epochs = LinearModel1.load(len(SURF_IDXS), model_path)
    except:
        model, alpha, train_idxs, val_idxs, epochs = LinearModelComp1.load(len(SURF_IDXS), len(SURF_IDXS), model_path)

    idxs = np.arange( len(data) ).tolist()

    root = "./{}_{}/".format( dataset, set_ )

    if not os.path.exists(root + "summaries"):
        try:
            os.makedirs(root + "summaries")
        except:
            pass
#     shutil.copy2( model_path, root )

    logger.debug("Dataset and model loaded, begin inference with #topics={}, generic?={}".format( len(idxs), generic ))

    if GPU_MODE:
        model.cuda()

    ix = np.where(np.array(data.groups) == group)[0][0]
#     print(ix)
    # group = data.groups[ix]            
    subset = data.get_subset_df(group, set_ )
    write_df = subset[["position", "doc_sents", "sent_id", "group", "target", "set", "doc_id", "num_words", "R1.R", "R1.P", "R2.R", "R2.P", "nouns", "prpns"]]
    surf_names = np.array(data.surf_names)[SURF_IDXS]

    if generic:
        if train_dataset in ["duc03", "duc04"]:
            K, X, _, _ = data[ix]
        elif train_dataset in ["tac08", "tac09"]:
            if set_ == "A":
                # logger.info("A")
                K, _, _, X, _, _, _, _, _ = data[ix]
            else:
                # logger.info("B")
                _, K, _, _, X, _, _, _, _ = data[ix]
        
        K, X = K.squeeze(), X.squeeze()[:, SURF_IDXS]
        fg = model.forward( X )[0]
        K_combined = combine_kernels(K, alpha, gamma1) 
        mmd = MMD(K_combined, fg)
        K = torch.einsum('ijk,k->ij', K, alpha)

    else:
        KA, KB, KAB, XA, X, _, _, _, _ = data[ix]
        KA, XA = KA.squeeze(), XA.squeeze()[:, SURF_IDXS]
        KB, X = KB.squeeze(), X.squeeze()[:, SURF_IDXS]
        KAB = KAB.squeeze()
        
        fA, fg = model.forward( XA, X )
        KA_combined = combine_kernels(KA, alpha, gamma1)  
        KB_combined = combine_kernels(KB, alpha, gamma1)  
        KAB_combined = combine_kernels(KAB, alpha, gamma1)
        mmd = MMD_comp( KB_combined, KA_combined, KAB_combined, fg, fA, lambdaa = lambdaa, diff = diff)
        K = torch.einsum('ijk,k->ij', KB, alpha)

    write_df["nf"] = X[:, np.where(surf_names=="nf")[0]] #normalised
    write_df["lexrank"] = X[:, np.where(surf_names=="lexrank")[0]].numpy() + 1.0
    
    write_df["tfisf"] = X[:, np.where(surf_names=="tfisf")[0]].numpy() ## normalised
    write_df["btfisf"] = X[:, np.where(surf_names=="btfisf")[0]].numpy() #normalised
    write_df["scores"] = softmax(fg.detach().numpy()) * len(subset)

    lengths = subset["num_words"].values
    keys = None
    if compress:
        keys = [int(sid.split("-")[0]) for sid in subset["sent_id"]]
    S, cost = greedy_maximize(mmd, budget = BUDGET, 
                costs = copy(lengths), r = r, verbose = False, keys = keys)
    K = K.numpy()
    np.fill_diagonal(K, 0.0)
    return K, write_df, S

In [27]:
# group, T = "D0910", 0.15
# group, T = "D0908", 0.1
group, T = "D0929", 0.1
K, df, S = infer("tac09_A/mmdpq_tac08-A_x.x_g2.25_b0.08_a0_SFbxx.net", 0.01, group)
print ( K.shape[0], (K > T).sum() // 2, (K > T).sum() / ( K.shape[0] * K.shape[1] ), len(S) )

May-27 00:33:29 INFO [data:531]=> loaded from ./data//tac09_y_hm_0.4.pik
May-27 00:33:29 INFO [Infer:61]=> surf feats: rel_pos,pos1,pos2,pos3,pos4+,#words,par_start,#nouns,query_sim,tfisf,btfisf,lexrank


62 192 0.09989594172736732 5


In [28]:
G=nx.from_numpy_matrix( K * ( K >= T ) )
y = np.zeros(K.shape[0])
y[S] = 1
nx.set_node_attributes(G, dict(enumerate(df["scores"].values)), 'score')
nx.set_node_attributes(G, dict(enumerate(df["tfisf"].values)), 'tfisf')
nx.set_node_attributes(G, dict(enumerate(df["btfisf"].values)), 'btfisf')
nx.set_node_attributes(G, dict(enumerate(df["lexrank"].values)), 'lexrank')
nx.set_node_attributes(G, dict(enumerate(df["target"].values)), 'target')
nx.set_node_attributes(G, dict(enumerate(df["R2.R"].values / df["num_words"].values)), 'R2.R')
nx.set_node_attributes(G, dict(enumerate(df["R1.R"].values / df["num_words"].values)), 'R1.R')
nx.set_node_attributes(G, dict(enumerate(1 - df["position"].values / df["doc_sents"].values)), 'rel_pos')
nx.set_node_attributes(G, dict(enumerate(df["nouns"].values)), 'nouns')
nx.set_node_attributes(G, dict(enumerate(df["num_words"].values)), 'words')
nx.set_node_attributes(G, dict(enumerate(y)), 'summary')
nx.write_graphml(G, group + ".graphml")

In [8]:
df.head()

Unnamed: 0,position,doc_sents,sent_id,group,set,doc_id,num_words,R1.R,R1.P,R2.R,R2.P,nouns,prpns,nf,lexrank,tfisf,btfisf,scores
5922,0,10,0,D0929,A,XIN_ENG_20050919.0201,22,0.1126,0.5114,0.03288,0.1548,5,4,(),1.023223,0.154728,0.761594,1.015507
5923,1,10,1,D0929,A,XIN_ENG_20050919.0201,22,0.08758,0.4167,0.01015,0.05,6,2,(),0.596115,-0.214641,-0.308914,0.989346
5924,2,10,2,D0929,A,XIN_ENG_20050919.0201,37,0.18,0.5143,0.07328,0.2132,7,6,(),1.407246,0.956008,0.68444,1.007415
5925,3,10,3,D0929,A,XIN_ENG_20050919.0201,22,0.1051,0.5,0.04298,0.2125,5,6,(),0.884419,-0.04343,0.199197,0.996231
5926,4,10,4,D0929,A,XIN_ENG_20050919.0201,15,0.07503,0.5,0.02023,0.1429,4,2,(),0.436922,-0.786257,-0.825672,0.982294
