In [42]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import RobertaTokenizer,RobertaModel, RobertaConfig

# Load the tokenizer and the model from the same "roberta-base" checkpoint so the
# token ids produced by one match what the other expects.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# Loading this checkpoint into RobertaModel leaves its lm_head.* weights unused;
# that is what the warning printed by this cell reports.
model = RobertaModel.from_pretrained("roberta-base")

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [28]:
import torch
import pandas as pd

In [29]:
# Read the training table; the relative path is resolved against the kernel's
# working directory. The preview in the next cell shows the columns id, url_legal,
# license, excerpt, target and standard_error.
df = pd.read_csv('train.csv')

In [30]:
# Preview the first rows of the loaded table.
df.head()

Unnamed: 0,id,url_legal,license,excerpt,target,standard_error
0,c12129c31,,,When the young people returned to the ballroom...,-0.340259,0.464009
1,85aa80a4c,,,"All through dinner time, Mrs. Fayre was somewh...",-0.315372,0.480805
2,b69ac6792,,,"As Roger had predicted, the snow departed as q...",-0.580118,0.476676
3,dd1000b26,,,And outside before the palace a great garden w...,-1.054013,0.450007
4,37c1b32fb,,,Once upon a time there were Three Bears who li...,0.247197,0.510845


In [18]:
class AttentionHead(torch.nn.Module):
    """Attention pooling that collapses the sequence axis of a feature tensor.

    Every position is scored with ``V(tanh(W(features)))``; the scores are
    softmax-normalised along ``dim=1`` and used as weights for a sum of the
    input features over that same axis.

    Args:
        in_features: Size of the last dimension of the tensors passed to
            ``forward``.
        hidden_dim: Width of the intermediate projection used for scoring.
        num_targets: Accepted for interface compatibility; this module does
            not use it.

    Attributes:
        in_features: Same as the constructor argument.
        middle_features: Same as ``hidden_dim``.
        out_features: Size of the last dimension of the pooled output. The
            output is a weighted sum of the *input* features, so this equals
            ``in_features``.
    """

    def __init__(self, in_features, hidden_dim, num_targets):
        super().__init__()
        self.in_features = in_features
        self.middle_features = hidden_dim

        self.W = torch.nn.Linear(in_features, hidden_dim)
        self.V = torch.nn.Linear(hidden_dim, 1)
        # forward() returns a weighted sum of `features`, so the output width
        # is in_features; reporting hidden_dim here was wrong whenever the two
        # differ.
        self.out_features = in_features

    def forward(self, features, attention_mask=None):
        """Pool ``features`` over ``dim=1`` with learned attention weights.

        Args:
            features: Tensor whose last dimension is ``in_features`` and whose
                ``dim=1`` is the axis to pool over, e.g.
                ``(batch, seq_len, in_features)``.
            attention_mask: Optional tensor with the shape of ``features``
                minus its last dimension, e.g. ``(batch, seq_len)``. Non-zero
                entries mark positions to attend to; zero entries (padding)
                are excluded from the softmax. ``None`` keeps the previous
                behaviour of attending to every position.

        Returns:
            Tensor equal to ``features`` with ``dim=1`` summed out, e.g.
            ``(batch, in_features)``.
        """
        att = torch.tanh(self.W(features))

        score = self.V(att)

        if attention_mask is not None:
            keep = attention_mask.to(device=score.device, dtype=torch.bool).unsqueeze(-1)
            # Use the most negative finite value instead of -inf so that a row
            # with no valid position yields uniform weights rather than NaN.
            score = score.masked_fill(~keep, torch.finfo(score.dtype).min)

        attention_weights = torch.softmax(score, dim=1)

        context_vector = attention_weights * features
        context_vector = torch.sum(context_vector, dim=1)

        return context_vector


In [64]:
# Show the full text of the first excerpt.
df['excerpt'].iloc[0]

'When the young people returned to the ballroom, it presented a decidedly changed appearance. Instead of an interior scene, it was a winter landscape.\nThe floor was covered with snow-white canvas, not laid on smoothly, but rumpled over bumps and hillocks, like a real snow field. The numerous palms and evergreens that had decorated the room, were powdered with flour and strewn with tufts of cotton, like snow. Also diamond dust had been lightly sprinkled on them, and glittering crystal icicles hung from the branches.\nAt each end of the room, on the wall, hung a beautiful bear-skin rug.\nThese rugs were for prizes, one for the girls and one for the boys. And this was the game.\nThe girls were gathered at one end of the room and the boys at the other, and one end was called the North Pole, and the other the South Pole. Each player was given a small flag which they were to plant on reaching the Pole.\nThis would have been an easy matter, but each traveller was obliged to wear snowshoes.'

In [70]:
# Tokenize the first ten excerpts as one batch: pad to a common length, cut
# anything longer than 512 tokens, and return PyTorch tensors.
pt_batch = tokenizer(
    df['excerpt'].iloc[:10].tolist(),
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

In [71]:
# Display the tokenizer output; the printed result holds 'input_ids' and
# 'attention_mask' tensors.
pt_batch

{'input_ids': tensor([[   0, 1779,    5,  ...,    1,    1,    1],
        [   0, 3684,  149,  ...,   69,    4,    2],
        [   0, 1620, 6682,  ...,    1,    1,    1],
        ...,
        [   0, 4148,    5,  ...,    1,    1,    1],
        [   0,  133, 2786,  ...,    1,    1,    1],
        [   0, 3762,  183,  ...,    1,    1,    1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [72]:
# Feature extraction only: run the forward pass without autograd so the graph
# of the whole encoder is not kept alive in memory. Element 0 of the model
# output is the tensor that the following cells feed to AttentionHead.
with torch.no_grad():
    x = model(**pt_batch)[0]

In [75]:
# Display the tensor stored in x.
x

tensor([[[-0.0676,  0.0801,  0.0089,  ..., -0.0234, -0.0443, -0.0295],
         [-0.0136,  0.2424, -0.1040,  ..., -0.4047,  0.0952, -0.0992],
         [ 0.0357, -0.0299,  0.1018,  ..., -0.2944,  0.0626, -0.1154],
         ...,
         [ 0.0610, -0.0185,  0.1139,  ...,  0.2247,  0.1071,  0.0041],
         [ 0.0610, -0.0185,  0.1139,  ...,  0.2247,  0.1071,  0.0041],
         [ 0.0610, -0.0185,  0.1139,  ...,  0.2247,  0.1071,  0.0041]],

        [[-0.0829,  0.1356,  0.0139,  ..., -0.0217, -0.0603, -0.0491],
         [ 0.1458, -0.1178,  0.0553,  ...,  0.0597,  0.2168,  0.1527],
         [ 0.1772,  0.1132,  0.1627,  ..., -0.3505,  0.0937,  0.0639],
         ...,
         [ 0.3085,  0.2996,  0.0602,  ...,  0.0091,  0.0128,  0.3206],
         [-0.0781,  0.1418, -0.0144,  ..., -0.0607, -0.0704, -0.0931],
         [ 0.0287,  0.0712,  0.0572,  ...,  0.0983,  0.0948, -0.0255]],

        [[-0.0928,  0.1262,  0.0197,  ..., -0.0684, -0.0534, -0.0382],
         [ 0.0978,  0.0350, -0.0059,  ..., -0

In [73]:
# Build the pooling head with both the input width and the scoring width at 768.
head = AttentionHead(in_features=768, hidden_dim=768, num_targets=1)

In [78]:
# Pooled vector for the first item in the batch, as a plain Python list.
head(x)[0].tolist()

[0.04364671930670738,
 0.05498413369059563,
 0.03521905466914177,
 -0.140118807554245,
 0.19014883041381836,
 -0.03924347832798958,
 -0.00560434814542532,
 0.16887755692005157,
 -0.03859715908765793,
 0.05036912485957146,
 -0.07033732533454895,
 -0.09216777980327606,
 0.0690608024597168,
 -0.07355324178934097,
 0.14234335720539093,
 0.08793966472148895,
 0.0678233802318573,
 0.1501721739768982,
 -0.061296552419662476,
 -0.07584792375564575,
 0.02649185061454773,
 0.07162951678037643,
 -0.08575887233018875,
 0.03568442538380623,
 -0.02945471927523613,
 -0.013446411117911339,
 0.060682233422994614,
 0.09183622151613235,
 0.02391025796532631,
 -0.04967126622796059,
 -0.04964624345302582,
 0.002295393031090498,
 0.05681797116994858,
 0.05402352288365364,
 -0.0596894808113575,
 0.03575623780488968,
 0.06461308896541595,
 0.030838683247566223,
 0.3308424949645996,
 0.03963182121515274,
 -0.05115049332380295,
 -0.04831239581108093,
 -0.058711208403110504,
 0.020512422546744347,
 -0.0177183337