In [1]:
import pandas as pd
import transformers 
import numpy as np
import torch
from collections import OrderedDict 

import sys

sys.path.insert(0, '..')

from roberta_model import RobertaConfig
from decompose_roberta import RobertaForSequenceClassificationDecomposed, decomp_activation
from decompose_roberta_mixed import RobertaForSequenceClassificationMixed
from analysis.preprocess_input import split_pos_neg_contributions


In [2]:
# Properties scored by the classifier. Downstream cells reshape the model
# output to (3, len(labels)), so the order of this list matters.
labels = [
    'awareness',
    'change_of_location',
    'change_of_state',
    'change_of_possession',
    'existed_after',
    'existed_before',
    'existed_during',
    'instigation',
    'sentient',
    'volition',
]

# Alternative, longer label list (18 entries), currently disabled:
# labels = ['awareness',
#        'change_of_location', 'change_of_state', 'changes_possession',
#        'created', 'destroyed', 'existed_after', 'existed_before',
#        'existed_during', 'exists_as_physical', 'instigation',
#        'location_of_event', 'makes_physical_contact', 'manipulated_by_another',
#        'predicate_changed_argument', 'sentient', 'stationary', 'volition']


In [3]:

# Load the fine-tuned checkpoint with the stock Hugging Face classifier class.
path = "../../combined_SPRL_roberta-dropout=0.1"
model = transformers.AutoModelForSequenceClassification.from_pretrained(path)

# Substring rewrites applied, in this exact order, to every parameter name:
#   * "roberta" prefix -> "bert"
#   * classification head: "classifier.dense" -> "bert.pooler.dense",
#     "classifier.out_proj" -> "classifier"
#   * LayerNorm "weight"/"bias" -> "gamma"/"beta"
#     ref : https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
# NOTE(review): presumably these are the names the decomposed implementation
# expects -- confirm against its module definitions.
_key_rewrites = (
    ("roberta", "bert"),
    ("classifier.dense", "bert.pooler.dense"),
    ("classifier.out_proj", "classifier"),
    ("LayerNorm.weight", "LayerNorm.gamma"),
    ("LayerNorm.bias", "LayerNorm.beta"),
)

new_state_dict = OrderedDict()
for key, value in model.state_dict().items():
    new_key = key
    for _old, _new in _key_rewrites:
        new_key = new_key.replace(_old, _new)
    new_state_dict[new_key] = value


In [4]:
# Reuse the configuration of the checkpoint loaded above. This must run before
# `model` is rebound further down in this cell.
decomposed_config = RobertaConfig.from_dict(model.config.to_dict())

# The classifier emits 3 scores per label, hence num_labels = 3 * len(labels).
decomposed_model = RobertaForSequenceClassificationDecomposed(
    config=decomposed_config,
    debug=False,
    num_labels=len(labels) * 3,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(path + "/tokenizer")

# From here on `model` refers to the network wrapped by the decomposed model,
# not to the Hugging Face model loaded earlier; the renamed weights go into it.
model = decomposed_model.model
model.load_state_dict(new_state_dict)
model.eval()


RobertaForSequenceClassification(
  (bert): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50267, 1024)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(514, 1024)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-23): 24 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-12)
  

In [318]:
# Token strings treated as predicate / argument when colouring links later on.
verb = ["<p>", "followed"]
arg = ["<a>", "the", "man"]
# Alternative argument / inputs that were tried, kept for reference:
# arg = ["<a>", "The", "police", "dog"]
# encoded_input = tokenizer("The guide dog<p> helped<p><a> the blind man<a> cross the road", return_tensors='pt')['input_ids']
# encoded_input = tokenizer("<a>The police dog<a><p> followed<p> the man.", return_tensors='pt')['input_ids']
encoded_input = tokenizer(
    "The police dog<p> followed<p><a> the man<a> using his scent.",
    return_tensors='pt',
)['input_ids']

# Token strings with any leading/trailing "Ġ" characters removed.
text = [token.strip("Ġ") for token in tokenizer.convert_ids_to_tokens(encoded_input[0])]
print(text)

with torch.no_grad():
    raw_logits = model(encoded_input)
    # Arrange the flat output as (batch, 3, len(labels)).
    logits = torch.Tensor(raw_logits.reshape(raw_logits.shape[0], 3, len(labels)))

    # Softmax over the 3-way axis, then the most probable class per label.
    probs = torch.softmax(logits, dim=1).squeeze()
    y_pred = torch.argmax(probs, dim=0)

for label, predicted, label_probs in zip(labels, y_pred, probs.T):
    print(label, predicted.item(), label_probs)

probs


['<s>', 'The', 'police', 'dog', '<p>', 'followed', '<p>', '<a>', 'the', 'man', '<a>', 'using', 'his', 'scent', '.', '</s>']
awareness 2 tensor([0.0219, 0.0474, 0.9307], dtype=torch.float64)
change_of_location 2 tensor([0.0025, 0.0065, 0.9910], dtype=torch.float64)
change_of_state 0 tensor([0.6040, 0.3324, 0.0636], dtype=torch.float64)
change_of_possession 0 tensor([0.9598, 0.0173, 0.0229], dtype=torch.float64)
existed_after 2 tensor([0.0031, 0.0037, 0.9932], dtype=torch.float64)
existed_before 2 tensor([1.2367e-04, 1.6266e-04, 9.9971e-01], dtype=torch.float64)
existed_during 2 tensor([2.4238e-04, 1.9480e-04, 9.9956e-01], dtype=torch.float64)
instigation 0 tensor([0.5514, 0.0779, 0.3707], dtype=torch.float64)
sentient 2 tensor([5.6175e-04, 7.0997e-04, 9.9873e-01], dtype=torch.float64)
volition 0 tensor([0.5489, 0.0892, 0.3619], dtype=torch.float64)


tensor([[2.1874e-02, 2.5167e-03, 6.0404e-01, 9.5978e-01, 3.1099e-03, 1.2367e-04,
         2.4238e-04, 5.5140e-01, 5.6175e-04, 5.4892e-01],
        [4.7431e-02, 6.4534e-03, 3.3239e-01, 1.7348e-02, 3.6955e-03, 1.6266e-04,
         1.9480e-04, 7.7901e-02, 7.0997e-04, 8.9225e-02],
        [9.3069e-01, 9.9103e-01, 6.3569e-02, 2.2871e-02, 9.9319e-01, 9.9971e-01,
         9.9956e-01, 3.7070e-01, 9.9873e-01, 3.6186e-01]], dtype=torch.float64)

In [319]:
# Per-token scores from the decomposed model: one (3, len(labels)) slab for
# every input position.
token_ids = encoded_input.squeeze()
num_tokens = len(token_ids)
logits = torch.zeros(num_tokens, 3, len(labels))

for i in range(num_tokens):
    # Two complementary masks: [only token i, every other token].
    mask = torch.zeros_like(token_ids)
    mask[i] = 1
    beta_mask = torch.stack([mask, 1 - mask]).unsqueeze(0)

    with torch.no_grad():
        decomposed_out = decomposed_model(
            input_ids=encoded_input,
            beta_mask=beta_mask,
            num_contributions=2,
        ).squeeze()
        # Entry 0 is kept.
        # NOTE(review): presumably the contribution of the first mask (token i) -- confirm.
        l = decomposed_out[0].reshape(3, len(labels))

    logits[i] = l

# Optional normalisation that was tried and is currently disabled:
# logits -= logits.mean(axis=[0, 1], keepdim=True)
# logits /= logits.std(axis=[0, 1],  keepdim=True)
# logits = torch.clip(logits, min=1e-2)


In [None]:
# # l = np.maximum(0, logits[:, [0,2], :])
# # l[:, 0] += np.maximum(0, logits[:, 1])
# # l = np.maximum(1e-2, l.unsqueeze(1).numpy())

# # l = split_pos_neg_contributions(torch.tensor(np.copy(logits)).unsqueeze(1).numpy())
# # # l = np.maximum(l, 1e-2)
# # l = np.maximum(l, 1e-2)
# # logits_ = l / l.sum((0, 1), keepdims=True)
# # l = np.maximum(l, 1e-2)
# print(logits_.shape)

# # neg = np.maximum(0, logits[:, 0, :]) - np.minimum(0, logits[:, 2, :])
# # pos = np.maximum(0, logits[:, 2, :]) - np.minimum(0, logits[:, 0, :])
# # l = torch.stack([neg, pos]).unsqueeze(1)
# # l.shape


(16, 1, 2, 10)
(16, 1, 2, 10)


In [320]:
# Fold the per-token scores into two non-negative channels:
#   neg = positive part of class 0  +  magnitude of the negative part of class 2
#   pos = positive part of class 2  +  magnitude of the negative part of class 0
# Class 1 is not used here.
#
# torch.clamp is used instead of np.maximum/np.minimum: the NumPy ufuncs only
# return tensors through implicit __array__/__array_wrap__ conversion, which is
# fragile and fails for tensors that are on an accelerator or require grad.
neg = torch.clamp(logits[:, 0, :], min=0) - torch.clamp(logits[:, 2, :], max=0)
pos = torch.clamp(logits[:, 2, :], min=0) - torch.clamp(logits[:, 0, :], max=0)

# Stack to (tokens, 1, 2, labels) and floor every value at 1e-2.
l = torch.clamp(torch.stack([neg, pos], dim=1).unsqueeze(1), min=1e-2)
l.shape


torch.Size([16, 1, 2, 10])

In [321]:
# `l` has shape (tokens, 1, 2, labels). Drop only the singleton axis 1 and move
# the 2-way axis last, giving (tokens, labels, 2): each cell of the frame then
# holds a [neg, pos] pair.
# An explicit permute replaces `.permute(1, 2, 0).T`, because `Tensor.T` on a
# tensor with more than two dimensions is deprecated. `squeeze(1)` replaces the
# bare `squeeze()`, which would also drop the token or label axis if its size were 1.
pair_scores = l.squeeze(1).permute(0, 2, 1)

df = pd.DataFrame(pair_scores.tolist(), columns=labels)
df["word"] = text

# Long format: one row per (word, property); rows are grouped by property and
# keep the token order within each property.
df = df.melt(id_vars="word", var_name="property", value_name="logits")
df = df.set_index(["word", "property"])
df


Unnamed: 0_level_0,Unnamed: 1_level_0,logits
word,property,Unnamed: 2_level_1
<s>,awareness,"[0.3133496344089508, 0.009999999776482582]"
The,awareness,"[0.3642670214176178, 0.029435209929943085]"
police,awareness,"[0.28553956747055054, 0.5853988528251648]"
dog,awareness,"[0.009999999776482582, 1.1203675270080566]"
<p>,awareness,"[0.009999999776482582, 2.0951318740844727]"
...,...,...
using,volition,"[0.009999999776482582, 0.7409586906433105]"
his,volition,"[0.009999999776482582, 1.521207332611084]"
scent,volition,"[0.009999999776482582, 0.8599779009819031]"
.,volition,"[1.8890135288238525, 0.009999999776482582]"


In [322]:
links = []

# Property whose scores are visualised. (The name shadows the `property`
# builtin; it is kept so that other cells relying on it keep working.)
property = "instigation"

# Fetch the rows of the selected property by position, not by token string.
# A label lookup such as df.loc[(word, property)] returns every row for that
# string, so repeated tokens (e.g. both "<p>" markers) would all receive the
# scores of the first occurrence. After `melt`, the rows of one property are
# stored in token order, so a positional pairing with `text` is exact.
_flat = df.reset_index()
property_scores = _flat.loc[_flat["property"] == property, "logits"].tolist()
assert len(property_scores) == len(text), "expected one row per token for the selected property"

for i, (word, scores) in enumerate(zip(text, property_scores)):
    # Colour by role of the token: predicate, argument, or anything else.
    if word in verb:
        color = "rgba(255, 0, 250, 0.5)"
    elif word in arg:
        color = "rgba(112, 255, 145, 0.5)"
    else:
        color = "rgba(0, 202, 255, 0.3)"

    # One link per score channel; target nodes are numbered after the tokens.
    for j, value in enumerate(scores):
        links.append(
            {
                'source': i,
                'target': len(text) + j,
                'value': value,
                'color': color,
                "word": word
            }
        )

links



indexing past lexsort depth may impact performance.


Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`



[{'source': 0,
  'target': 16,
  'value': 0.5685617327690125,
  'color': 'rgba(0, 202, 255, 0.3)',
  'word': '<s>'},
 {'source': 0,
  'target': 17,
  'value': 0.009999999776482582,
  'color': 'rgba(0, 202, 255, 0.3)',
  'word': '<s>'},
 {'source': 1,
  'target': 16,
  'value': 0.009999999776482582,
  'color': 'rgba(0, 202, 255, 0.3)',
  'word': 'The'},
 {'source': 1,
  'target': 17,
  'value': 0.2686059772968292,
  'color': 'rgba(0, 202, 255, 0.3)',
  'word': 'The'},
 {'source': 2,
  'target': 16,
  'value': 0.4956572651863098,
  'color': 'rgba(0, 202, 255, 0.3)',
  'word': 'police'},
 {'source': 2,
  'target': 17,
  'value': 0.166993647813797,
  'color': 'rgba(0, 202, 255, 0.3)',
  'word': 'police'},
 {'source': 3,
  'target': 16,
  'value': 0.009999999776482582,
  'color': 'rgba(0, 202, 255, 0.3)',
  'word': 'dog'},
 {'source': 3,
  'target': 17,
  'value': 0.14988034963607788,
  'color': 'rgba(0, 202, 255, 0.3)',
  'word': 'dog'},
 {'source': 4,
  'target': 16,
  'value': 0.30049222

In [328]:
import plotly.graph_objects as go
links = pd.DataFrame(links)

# Display names for the special marker tokens.
replace = {
    "<a>": "<|arg|>",
    "<p>": "<|pred|>"
}

# Nodes: one per token (left column), followed by the two target nodes.
node_labels = [replace.get(t, t) for t in text] + ["negative", "positive"]
node_x = [0.1] * len(text) + [0.5] * 2
node_y = list(np.linspace(0.01, 0.99, len(text))) + list(np.linspace(0.01, 0.99, 2))

sankey = go.Sankey(
    arrangement="snap",
    node={
        "label": node_labels,
        'x': node_x,
        "y": node_y,
        'color': "grey",
        'pad': 10,
    },
    link={column: links[column].tolist() for column in ("source", "target", "value", "color")},
)

fig = go.Figure(sankey)
fig.update_layout(height=500, width=1000, autosize=False)

fig.show()


In [None]:
# NOTE(review): an earlier cell rebinds `model` to `decomposed_model.model`, so
# the config is read from that object here -- confirm it exposes `.config`.
mixed_config = RobertaConfig.from_dict(model.config.to_dict())

# Same label layout as the decomposed model: 3 scores per label.
mixed_model = RobertaForSequenceClassificationMixed(
    config=mixed_config,
    state_dict=new_state_dict,
    segment_layer=23,
    debug=False,
    num_labels=len(labels) * 3,
)


In [None]:
# Override the segment layer chosen at construction time.
mixed_model.segment_layer = 24

token_ids = encoded_input.squeeze()
num_tokens = len(token_ids)
logits = torch.zeros(num_tokens, 3, len(labels))

for i in range(num_tokens):
    # Two complementary masks: [only token i, every other token].
    mask = torch.zeros_like(token_ids)
    mask[i] = 1

    # Two singleton axes are inserted (positions 1 and 3) before the call.
    beta_mask = torch.stack([mask, 1 - mask]).unsqueeze(1).unsqueeze(3)
    with torch.no_grad():
        mixed_out = mixed_model(
            input_ids=encoded_input,
            beta_mask=beta_mask,
            attention_mask=None,
            num_contributions=2,
        )
        # Entry 0 of the returned logits, arranged as (3, len(labels)).
        l = mixed_out["logits"][0].reshape(3, len(labels))
    logits[i] = l

# Standardise per label (statistics taken over the token and class axes), then
# floor at 1e-3.
centered = logits - logits.mean(dim=[0, 1], keepdim=True)
standardized = centered / centered.std(dim=[0, 1], keepdim=True)
logits = torch.clip(standardized, min=1e-3)


In [None]:
# Property to inspect.
role = "volition"

# `logits` is (tokens, 3, labels); slicing out the chosen label leaves a
# (tokens, 3) table whose columns are the three classes in index order 0, 1, 2.
role_scores = logits[:, :, labels.index(role)]

# Build the final frame in one step. The previous version assigned new columns
# onto a column-subset of another frame, which can raise SettingWithCopyWarning,
# and unpacked list cells with three separate `.apply` passes.
df = pd.DataFrame(
    role_scores.tolist(),
    columns=["negative", "neutral", "positive"],
    index=pd.Index(text, name="word"),
)
df


Unnamed: 0_level_0,negative,neutral,positive
word,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
<s>,2.149695,0.001,0.001
The,0.461468,0.035667,0.183111
creep,0.479681,0.14303,0.195693
<p>,0.486195,0.001,0.175133
stalk,0.54704,0.001,0.111457
ed,0.53068,0.001,0.138636
<p>,0.560194,0.001,0.001
<a>,0.557767,0.001,0.001
my,0.521977,0.001,0.001
girlfriend,0.480594,0.001,0.001


In [None]:
links = []

# `df` holds one row per token, in token order, indexed by the token string.
assert len(df) == len(text), "expected one row per token"

for i, word in enumerate(text):
    # Colour by role of the token: predicate, argument, or anything else.
    if word in verb:
        color = "rgba(255, 0, 250, 0.5)"
    elif word in arg:
        color = "rgba(112, 255, 145, 0.5)"
    else:
        color = "rgba(0, 202, 255, 0.3)"

    # Read the row by position. `df.loc[word]` returns a DataFrame when a token
    # string occurs more than once (e.g. "<p>"), and iterating a DataFrame
    # yields column names rather than the scores.
    for j, value in enumerate(df.iloc[i]):
        links.append(
            {
                'source': i,
                'target': len(text) + j,
                'value': value,
                'color': color
            }
        )


In [None]:
import plotly.graph_objects as go
links = pd.DataFrame(links)

# Nodes: one per token (left column), followed by one per column of `df`.
target_labels = df.columns.tolist()
node_labels = text + target_labels

# Every node needs a coordinate. The earlier hard-coded `len(text) - 4` and `3`
# produced fewer coordinates than nodes, so positions were misaligned with the
# labels. Sizes are now derived from the data, as in the first Sankey figure.
node_x = [0.1] * len(text) + [0.5] * len(target_labels)
node_y = (
    list(np.linspace(0.01, 0.99, len(text)))
    + list(np.linspace(0.01, 0.99, len(target_labels)))
)
assert len(node_x) == len(node_y) == len(node_labels)

fig = go.Figure(
    go.Sankey(
        arrangement="snap",
        node={
            "label": node_labels,
            'x': node_x,
            "y": node_y,
            'color': "grey",
        },
        link={
            "source": links["source"].tolist(),
            "target": links["target"].tolist(),
            "value": links["value"].tolist(),
            "color": links["color"].tolist(),
        },
    )
)

fig.update_layout(height=600)

fig.show()
