In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
mpl.style.use('bmh')

In [94]:
import torch
import math
import random

from tqdm import tqdm_notebook
from boltons.iterutils import pairwise
from scipy.stats import kendalltau

from sent_order.models.kt_regression import SentenceEncoder, Regressor, Corpus
from sent_order.perms import sample_uniform_perms

In [3]:
sent_encoder = torch.load(
    '../../data/models/new/kt-reg/sent_encoder.68.bin',
    map_location={'cuda:0': 'cpu'},
)

In [None]:
regressor = torch.load(
    '../../data/models/new/kt-reg/regressor.68.bin',
    map_location={'cuda:0': 'cpu'},
)

In [5]:
train = Corpus('../../data/train.json/', 10000)

100%|██████████| 10000/10000 [00:01<00:00, 8161.12it/s]


In [39]:
grafs = [g for g in train.grafs if len(g.sentences) == 8]

In [40]:
len(grafs)

639

In [113]:
graf = grafs[50]

In [114]:
[' '.join(s.tokens) for s in graf.sentences]

['Recently , the chiral - induced spin selectivity in molecular systems has attracted extensive interest among the scientific communities .',
 'Here , we investigate the effect of the gate voltage on spin - selective electron transport through the $ \\alpha$-helical peptide / protein molecule contacted by two nonmagnetic electrodes .',
 'Based on an effective model Hamiltonian and the Landauer - B\\"uttiker formula , we calculate the conductance and the spin polarization under an external electric field which is perpendicular to the helix axis of the $ \\alpha$-helical peptide / protein molecule .',
 'Our results indicate that both the magnitude and the direction of the gate field have a significant effect on the conductance and the spin polarization .',
 'The spin filtration efficiency can be improved by properly tuning the gate voltage , especially in the case of strong dephasing regime .',
 'And the spin polarization increases monotonically with the molecular length without the gate

In [115]:
sents = sent_encoder(graf.sentence_variables())

In [116]:
regressor(sents.unsqueeze(0)).data[0]

0.0003949403762817383

In [117]:
scores = []
for _ in range(100):
    perm = torch.randperm(8)
    pred = regressor(sents[perm].unsqueeze(0)).data[0]
    scores.append((perm.tolist(), pred))

In [118]:
order, score = sorted(scores, key=lambda x: x[1])[0]

In [119]:
for _ in tqdm_notebook(range(1000)):
    
    i1, i2 = random.sample(range(len(order)), 2)

    new_order = order.copy()
    new_order[i1], new_order[i2] = new_order[i2], new_order[i1]

    perm = torch.LongTensor(new_order)
    new_score = regressor(sents[perm].unsqueeze(0)).data[0]

    if new_score < score:
        score = new_score
        order = new_order

    print(score)

0.1978311985731125
0.1978311985731125
0.1978311985731125
0.1978311985731125
0.17335037887096405
0.17335037887096405
0.1705200970172882
0.14625321328639984
0.14625321328639984
0.14625321328639984
0.13627660274505615
0.13627660274505615
0.010469645261764526
0.010469645261764526
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.0062202513217926025
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.0028176903724

0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.00281769037

0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.002817690372467041
0.00281769037

In [120]:
order

[0, 4, 1, 2, 3, 5, 6, 7]

In [121]:
kendalltau(order, range(len(order)))

KendalltauResult(correlation=0.78571428571428559, pvalue=0.0064928577450838959)