# Instructions:

1. Run the `wget` command below to download the decomposable-attention (ELMo) model archive
2. Untar the downloaded file and modify the `config.json` file
    - Change the key `model.type` to "textual_entailment"
3. Re-tar the modified files and save the archive as `./models/textual_entailment.tar.gz` (the path the predictor cell below loads from)
4. Run the next two lines of code to make sure everything worked

In [11]:
# Download the decomposable-attention ELMo archive. `-O` writes to the given path;
# wget does not create the `./models/` directory, so it has to exist beforehand.
!wget https://storage.googleapis.com/allennlp-public-models/decomposable-attention-elmo-2020.04.09.tar.gz -O ./models/decomposable-attention-elmo-2020.04.09.tar.gz

--2021-05-01 16:56:09--  https://storage.googleapis.com/allennlp-public-models/decomposable-attention-elmo-2020.04.09.tar.gz
Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'
Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.13.80, 172.253.63.128, 142.250.31.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.13.80|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 697659038 (665M) [application/x-gzip]
Saving to: ‘./models/decomposable-attention-elmo-2020.04.09.tar.gz’


2021-05-01 16:56:29 (32.9 MB/s) - ‘./models/decomposable-attention-elmo-2020.04.09.tar.gz’ saved [697659038/697659038]



In [1]:
from typing import List, Dict, Any
import torch

from allennlp_models.pair_classification.models import DecomposableAttention
from allennlp.data import TextFieldTensors
from allennlp.models.model import Model
from allennlp.nn.util import get_text_field_mask, masked_softmax, weighted_sum

# `exist_ok=True` allows this name to be registered more than once. Without it,
# executing this definition a second time in the same process (re-running the
# cell, or importing a module that contains the same class after the cell has
# run) raises `ConfigurationError: Cannot register textual_entailment as Model`.
@Model.register("textual_entailment", exist_ok=True)
class TextualEntailment(DecomposableAttention):
    """
    A `DecomposableAttention` model registered under the name `"textual_entailment"`.

    Its `forward` returns, alongside the label logits and probabilities, the
    aggregated comparison vector and both attention matrices so that callers can
    inspect them.
    """

    def forward(  # type: ignore
        self,
        premise: TextFieldTensors,
        hypothesis: TextFieldTensors,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        """
        # Parameters
        premise : `TextFieldTensors`
            From a `TextField`
        hypothesis : `TextFieldTensors`
            From a `TextField`
        label : `torch.IntTensor`, optional (default = `None`)
            From a `LabelField`
        metadata : `List[Dict[str, Any]]`, optional (default = `None`)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        # Returns
        An output dictionary consisting of:
        label_logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_labels)` representing unnormalised log
            probabilities of the entailment label.
        label_probs : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_labels)` representing probabilities of the
            entailment label.
        aggregate_input : `torch.FloatTensor`
            The concatenation of the summed premise and hypothesis comparison vectors
            that is fed to the aggregation feed-forward layer.
        h2p_attention : `torch.FloatTensor`
            A tensor of shape `(batch_size, hypothesis_length, premise_length)` holding the
            masked-softmax attention of each hypothesis token over the premise tokens.
        p2h_attention : `torch.FloatTensor`
            A tensor of shape `(batch_size, premise_length, hypothesis_length)` holding the
            masked-softmax attention of each premise token over the hypothesis tokens.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised. Only present when `label` is given.
        premise_tokens, hypothesis_tokens : `List[List[str]]`, optional
            The values stored under the same keys in `metadata`, one entry per instance.
            Only present when `metadata` is given.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise)
        hypothesis_mask = get_text_field_mask(hypothesis)

        # The encoders are optional; when absent the raw embeddings are used.
        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

        # Attend step: project both sequences with the same feed-forward layer and
        # score every premise token against every hypothesis token.
        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        # Compare step: pair each token with its attended counterpart from the other sequence.
        premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        # Zero out padded positions before summing over the sequence dimension.
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(hypothesis_compare_input)
        # Zero out padded positions before summing over the sequence dimension.
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        # Aggregate step: classify from the concatenated sequence-level vectors.
        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "aggregate_input": aggregate_input,
            "h2p_attention": h2p_attention,
            "p2h_attention": p2h_attention,
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["premise_tokens"] = [x["premise_tokens"] for x in metadata]
            output_dict["hypothesis_tokens"] = [x["hypothesis_tokens"] for x in metadata]

        return output_dict

In [9]:
from allennlp.predictors.predictor import Predictor
# NOTE(review): the ConfigurationError in this cell's output says "textual_entailment" is
# already registered; presumably this import registers the class a second time after an
# earlier cell in the same kernel already did so — confirm.
from models.textual_entailment import TextualEntailment 

# Build a predictor from the re-tarred archive produced in step 3 of the instructions.
predictor = Predictor.from_path("./models/textual_entailment.tar.gz")

ConfigurationError: Cannot register textual_entailment as Model; name already in use for TextualEntailment

In [21]:
# NOTE(review): `model` is not defined in any cell shown here — confirm where it is loaded.
# This only reads the `default_predictor` attribute; it does not run a prediction.
outputs = model.default_predictor

In [27]:
from allennlp_models.pair_classification.predictors import TextualEntailmentPredictor
from allennlp_models.pair_classification.dataset_readers import SnliReader

# Build the predictor by hand from a default-configured SNLI reader and `model`.
# NOTE(review): `model` is not defined in any cell shown here — confirm where it is loaded.
reader = SnliReader()
predictor = TextualEntailmentPredictor(model, reader)

In [6]:
# Run the predictor on a single premise/hypothesis pair.
outputs = predictor.predict(
        premise="Two women are wandering along the shore drinking iced tea.",
        hypothesis="Two women are sitting on a blanket near some rocks talking about politics."
    )
# A bare expression on the last line makes the notebook display the result.
outputs

{'label_logits': [-3.349426507949829, 4.429577827453613, 0.9683252573013306],
 'label_probs': [0.00040552925202064216,
  0.9691717624664307,
  0.030422702431678772],
 'aggregate_input': [0.29036271572113037,
  0.0028719918336719275,
  0.15912583470344543,
  0.05093297362327576,
  0.39576515555381775,
  0.6307034492492676,
  0.11962313950061798,
  0.0,
  0.18295618891716003,
  0.0,
  0.9975005388259888,
  0.0,
  0.3012509346008301,
  0.22798886895179749,
  0.0,
  1.1960465908050537,
  0.3020651042461395,
  0.5289392471313477,
  0.9563127756118774,
  0.18653690814971924,
  0.47025221586227417,
  0.07228167355060577,
  0.04665301367640495,
  0.015632718801498413,
  0.08078794926404953,
  0.08994587510824203,
  0.07313307374715805,
  0.0,
  0.9704352617263794,
  1.4826524257659912,
  0.43929317593574524,
  0.11609132587909698,
  0.4479944109916687,
  0.6941320300102234,
  0.0,
  0.21438650786876678,
  0.05199456214904785,
  0.4999952018260956,
  0.028935890644788742,
  0.040380336344242096

The output above should be close to the following reference output:

```
{'label_logits': [-3.349437952041626, 4.429553985595703, 0.9683454036712646],
 'label_probs': [0.00040553370490670204,
  0.9691703915596008,
  0.030424002557992935],
 'h2p_attention': [[0.6542813777923584,
   0.04181642830371857,
   0.044751670211553574,
   0.04774125665426254,
   0.032621681690216064,
   0.02951720543205738,
   0.0222022645175457,
   0.02251223288476467,
   0.022536294534802437,
   0.03395495191216469,
   0.025392329320311546,
   0.02267230674624443],
  [2.8718852263409644e-05,
   0.9997602105140686,
   2.671044239832554e-05,
   3.0569222872145474e-05,
   1.630610495340079e-05,
   1.6656915249768645e-05,
   2.2438669475377537e-05,
   1.677172986092046e-05,
   1.8310134692001157e-05,
   3.356780143803917e-05,
   1.4835497495369054e-05,
   1.4929611097613815e-05],
  [0.10177356749773026,
   0.08759545534849167,
   0.1118234395980835,
   0.11648961156606674,
   0.08897500485181808,
   0.10866256058216095,
   0.06278194487094879,
   0.06688541173934937,
   0.06301344931125641,
   0.07898420095443726,
   0.06009376421570778,
   0.05292144790291786],
  [0.0004623697604984045,
   0.0005366479163058102,
   0.0013094748137518764,
   0.9844369888305664,
   0.0037109607364982367,
   0.0008001072565093637,
   0.000759141577873379,
   0.00399616826325655,
   0.0005880572716705501,
   0.0022712009958922863,
   0.0006370970513671637,
   0.0004917914629913867],
  [0.021879466250538826,
   0.016263877972960472,
   0.043620962649583817,
   0.3232453167438507,
   0.12519335746765137,
   0.21580788493156433,
   0.03572555258870125,
   0.09783807396888733,
   0.03638046607375145,
   0.04659603163599968,
   0.022638076916337013,
   0.014810925349593163],
  [0.017246725037693977,
   0.011929051950573921,
   0.031438130885362625,
   0.2378472089767456,
   0.06237583979964256,
   0.4272131025791168,
   0.038840629160404205,
   0.05050673708319664,
   0.05539732053875923,
   0.04054877907037735,
   0.016226788982748985,
   0.010429643094539642],
  [0.0102377999573946,
   0.010660983622074127,
   0.012148184701800346,
   0.4098280966281891,
   0.024485282599925995,
   0.049786731600761414,
   0.36476778984069824,
   0.028135929256677628,
   0.031522396951913834,
   0.043300505727529526,
   0.008045729249715805,
   0.007080606184899807],
  [0.012575305067002773,
   0.013238711282610893,
   0.032401192933321,
   0.24183233082294464,
   0.13379737734794617,
   0.3817344009876251,
   0.06835024803876877,
   0.05440527945756912,
   0.015210701152682304,
   0.02560754492878914,
   0.012246701866388321,
   0.008600177243351936],
  [0.02550632879137993,
   0.02194177731871605,
   0.04952565208077431,
   0.0795864388346672,
   0.0690709576010704,
   0.5571142435073853,
   0.05692204087972641,
   0.03753482177853584,
   0.030629552900791168,
   0.04115668684244156,
   0.017573071643710136,
   0.013438398018479347],
  [0.0032999140676110983,
   0.0037106797099113464,
   0.00489591620862484,
   0.01287841983139515,
   0.009935688227415085,
   0.050064265727996826,
   0.6458783745765686,
   0.012672769837081432,
   0.16958104074001312,
   0.08052573353052139,
   0.0033406356815248728,
   0.0032165131997317076],
  [0.011913090944290161,
   0.009290369227528572,
   0.0199927669018507,
   0.42288175225257874,
   0.02823151834309101,
   0.02082984149456024,
   0.014663374982774258,
   0.33708497881889343,
   0.058575790375471115,
   0.05190961807966232,
   0.017143506556749344,
   0.007483404595404863],
  [0.028186608105897903,
   0.025777099654078484,
   0.05010043457150459,
   0.40535101294517517,
   0.07974492758512497,
   0.07758960872888565,
   0.0279350895434618,
   0.1694972813129425,
   0.04901993274688721,
   0.05365008860826492,
   0.021943798288702965,
   0.011204180307686329],
  [0.003791020717471838,
   0.0037876616697758436,
   0.005148179829120636,
   0.04601581394672394,
   0.007641100324690342,
   0.009489400312304497,
   0.01279785018414259,
   0.11390943080186844,
   0.20984220504760742,
   0.5805724263191223,
   0.0038965558633208275,
   0.0031084204092621803],
  [0.07601422071456909,
   0.06725233048200607,
   0.08519984036684036,
   0.13195280730724335,
   0.07711151987314224,
   0.08093569427728653,
   0.0706259161233902,
   0.09599754959344864,
   0.08104491233825684,
   0.07409657537937164,
   0.09086937457323074,
   0.06889919191598892],
  [0.08347291499376297,
   0.08220735937356949,
   0.08219696581363678,
   0.08517542481422424,
   0.07966078072786331,
   0.08381462097167969,
   0.0845530554652214,
   0.08125219494104385,
   0.08679300546646118,
   0.0875125303864479,
   0.08170206844806671,
   0.08165907859802246]],
 'p2h_attention': [[0.5794978737831116,
   0.038077544420957565,
   0.037115372717380524,
   0.01881016418337822,
   0.02877109684050083,
   0.03250500187277794,
   0.0274263434112072,
   0.028178894892334938,
   0.03750938922166824,
   0.02097991108894348,
   0.03461119532585144,
   0.050881993025541306,
   0.024034641683101654,
   0.02161012962460518,
   0.01999048702418804],
  [2.7932510420214385e-05,
   0.9997096061706543,
   2.4092194507829845e-05,
   1.646525925025344e-05,
   1.6129462892422453e-05,
   1.6956086255959235e-05,
   2.1539437511819415e-05,
   2.2373153115040623e-05,
   2.4335489797522314e-05,
   1.7792224753065966e-05,
   2.0356405002530664e-05,
   3.509387170197442e-05,
   1.8110436940332875e-05,
   1.4419360013562255e-05,
   1.4847881175228395e-05],
  [0.055058859288692474,
   0.04919420927762985,
   0.05664772167801857,
   0.07399989664554596,
   0.07967934757471085,
   0.082305908203125,
   0.04520675912499428,
   0.10085493326187134,
   0.10117033123970032,
   0.043238017708063126,
   0.0806855633854866,
   0.12563005089759827,
   0.04533838480710983,
   0.03364589437842369,
   0.027344146743416786],
  [0.0009355745278298855,
   0.000896776095032692,
   0.0009399468544870615,
   0.8861116766929626,
   0.009404795244336128,
   0.009918340481817722,
   0.024291837587952614,
   0.011989947408437729,
   0.002589575247839093,
   0.0018115945858880877,
   0.027183692902326584,
   0.01619011163711548,
   0.006454849150031805,
   0.0008300007320940495,
   0.0004513250896707177],
  [0.021218551322817802,
   0.015877211466431618,
   0.023829173296689987,
   0.11086936295032501,
   0.12089911103248596,
   0.08633403480052948,
   0.04817131906747818,
   0.22017863392829895,
   0.07459499686956406,
   0.04638965427875519,
   0.060234926640987396,
   0.10571739077568054,
   0.035576194524765015,
   0.01609918661415577,
   0.014010204002261162],
  [0.007183174602687359,
   0.006068067625164986,
   0.01088811457157135,
   0.008943452499806881,
   0.0779723972082138,
   0.2212289720773697,
   0.03664618358016014,
   0.23502855002880096,
   0.22510765492916107,
   0.08745463192462921,
   0.016627691686153412,
   0.0384838730096817,
   0.016530055552721024,
   0.0063220299780368805,
   0.005515076220035553],
  [0.0034150348510593176,
   0.005166654475033283,
   0.00397616159170866,
   0.0053633530624210835,
   0.008158475160598755,
   0.01271277666091919,
   0.16970250010490417,
   0.026598431169986725,
   0.01453727763146162,
   0.713119387626648,
   0.007398379035294056,
   0.008757534436881542,
   0.014090587384998798,
   0.0034868817310780287,
   0.003516556229442358],
  [0.007020140998065472,
   0.00782924797385931,
   0.008587962947785854,
   0.05723832547664642,
   0.045296810567379,
   0.03351450338959694,
   0.02653764933347702,
   0.042922645807266235,
   0.019434189423918724,
   0.028366943821310997,
   0.3448033630847931,
   0.10772685706615448,
   0.25426188111305237,
   0.009608656167984009,
   0.006850983947515488],
  [0.006401711143553257,
   0.007786095142364502,
   0.007370183244347572,
   0.007672714535146952,
   0.01534313801676035,
   0.03348563611507416,
   0.0270836241543293,
   0.0109315300360322,
   0.014446379616856575,
   0.3457837700843811,
   0.05458037927746773,
   0.02838052064180374,
   0.4266784191131592,
   0.007389492355287075,
   0.006666360888630152],
  [0.005955484230071306,
   0.00881356280297041,
   0.005704079754650593,
   0.01829722709953785,
   0.01213375199586153,
   0.015133792534470558,
   0.022971050813794136,
   0.011363179422914982,
   0.011985580436885357,
   0.10138233751058578,
   0.029865272343158722,
   0.019178668037056923,
   0.7288942933082581,
   0.004171451088041067,
   0.004150253254920244],
  [0.055401433259248734,
   0.0484546460211277,
   0.05398578941822052,
   0.06384691596031189,
   0.07333149760961533,
   0.07533685117959976,
   0.05309552699327469,
   0.06760141998529434,
   0.0636606439948082,
   0.052319228649139404,
   0.12269391119480133,
   0.09758078306913376,
   0.06085463985800743,
   0.06363722681999207,
   0.04819945618510246],
  [0.06747952848672867,
   0.06651806831359863,
   0.06485441327095032,
   0.06723154336214066,
   0.06544717401266098,
   0.06605444103479385,
   0.0637412816286087,
   0.0647592544555664,
   0.06640926003456116,
   0.06871878355741501,
   0.07306013256311417,
   0.0679658055305481,
   0.06622321903705597,
   0.06582117825746536,
   0.06571603566408157]],
 'premise_tokens': ['Two',
  'women',
  'are',
  'wandering',
  'along',
  'the',
  'shore',
  'drinking',
  'iced',
  'tea',
  '.',
  '@@NULL@@'],
 'hypothesis_tokens': ['Two',
  'women',
  'are',
  'sitting',
  'on',
  'a',
  'blanket',
  'near',
  'some',
  'rocks',
  'talking',
  'about',
  'politics',
  '.',
  '@@NULL@@'],
 'label': 'contradiction'}
 ```