In [39]:
import logging
import time
from pathlib import Path

import torch
from transformers import BartForSequenceClassification, BartTokenizer, BartConfig

# internal libraries
from ressources import target_to_label

# Seed torch's global RNG so repeated runs of the notebook start from the
# same random state.
SEED = 555
torch.manual_seed(SEED)

# Root logger threshold is WARNING, so messages emitted with logging.info()
# further down are not displayed unless this level is lowered.
logging.basicConfig(level=logging.WARNING)

In [40]:
# Passage to classify. The value starts and ends with a newline, exactly as
# a triple-quoted block with the sentence on its own line would produce.
text = (
    "\n"
    "Water infrastructure is particularly affected by underinvestment "
    "with highly negative potential consequences to the welfare of the "
    "population and the economy."
    "\n"
)

In [41]:
# Output folder, created next to the notebook on first run (no error if it
# already exists).
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)

# Timestamp (seconds since the epoch) taken when this cell runs.
now = time.time()

# Reverse lookup: label text -> target id. If several targets share the
# same label text, only the last one survives in this mapping.
label_to_target = {
    label_text: target_id for target_id, label_text in target_to_label.items()
}

In [42]:
# Every target id known to target_to_label, followed by the extra "0" entry.
targets = list(target_to_label.keys()) + ["0"]

In [43]:
# Single checkpoint identifier shared by the config, tokenizer and weights.
MODEL_NAME = "valhalla/distilbart-mnli-12-9"

config = BartConfig.from_pretrained(MODEL_NAME)
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
model = BartForSequenceClassification.from_pretrained(MODEL_NAME)

# Put the model in evaluation mode; as the last expression of the cell, the
# returned module is also what the notebook displays.
model.eval()

BartForSequenceClassification(
  (model): BartModel(
    (shared): Embedding(50265, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((10

In [44]:
def predict(premise, hypothesis):
    """Score a hypothesis against a premise with the MNLI-pretrained model.

    The pair is encoded with the module-level ``tokenizer`` and passed
    through the module-level ``model`` without tracking gradients. The
    "neutral" logit (index 1) is discarded and a softmax is taken over the
    two remaining logits (indices 0 and 2); the share assigned to
    "entailment" (index 2) is reported.

    Args:
        premise: Text the hypothesis is judged against.
        hypothesis: Statement whose truth is being estimated.

    Returns:
        The entailment probability as a percentage (float, 0-100).
    """
    # Encode the (premise, hypothesis) pair as one truncated sequence.
    encoded_pair = tokenizer.encode(
        premise,
        hypothesis,
        truncation=True,
        return_tensors="pt",
    )

    # Inference only: no gradients needed.
    with torch.no_grad():
        nli_logits = model(encoded_pair)[0]

    # Keep columns 0 and 2, dropping column 1 ("neutral"); after the
    # softmax, column 1 of the reduced tensor is the "entailment" share.
    two_way_logits = nli_logits[:, [0, 2]]
    two_way_probs = torch.softmax(two_way_logits, dim=1)

    # .item() requires a single element, i.e. one pair per call.
    true_prob = two_way_probs[:, 1].item() * 100
    logging.info(f"Probability that '{hypothesis}' is true: {true_prob:0.2f}%")

    return true_prob

In [45]:
# All label texts, kept as a module-level name for any later cell that uses it.
labels = list(target_to_label.values())

# target id -> entailment probability (in percent) for `text`.
results = {}

start_time = time.time()

# Walk the (target, label) pairs directly rather than mapping each label back
# to a target through an inverted dictionary: an inverted mapping keeps only
# one target per label text, so targets sharing a label would silently be
# missing from `results`.
for target, label in target_to_label.items():
    # Build hypothesis
    hypothesis = "The context is " + label

    # Run prediction
    true_prob = predict(text, hypothesis)

    results[target] = true_prob

total_time = time.time() - start_time
# Emitted at INFO level; only visible when the logging threshold allows it.
logging.info("Total prediction time : %.2fs", total_time)

In [46]:
# Number of best-scoring targets to report.
top_X = 5

# Rank targets by score, highest first, and keep the first top_X.
# - `results` is left untouched, so re-running this cell gives the same answer
#   (repeatedly popping the maximum would shrink `results` on every run).
# - Slicing tolerates fewer than top_X entries instead of raising on an
#   empty dictionary.
# - sorted() is stable even with reverse=True, so tied scores keep their
#   insertion order, matching a "take the first maximum" selection.
top_X_targets = sorted(results, key=results.get, reverse=True)[:top_X]

print(top_X_targets)
for t in top_X_targets:
    print(target_to_label[t])


['12.2', '12.6', '12.7', '16.6', '12.8']
use of natural resources
encourage companies to integrate sustainability into their reporting cycle
promote public procurement practices that are sustainable
develop effective, accountable and transparent institutions
ensure that people have the relevant information and awareness for sustainable development and lifestyles
