In [1]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from functools import wraps, partial

model_name = "allenai/unifiedqa-v2-t5-large-1363200"  # you can specify the model size here
tokenizer = T5Tokenizer.from_pretrained(model_name)

model = T5ForConditionalGeneration.from_pretrained(model_name,
                                                   device_map='cuda:0')#'auto')


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [2]:
def DEFAULT_COMPUTE_BIAS(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

In [3]:
import pickle

dataset_test = pickle.load(open('test_without_abcd.pkl', 'rb'))


In [4]:
input_ids = tokenizer.encode(dataset_test[0][0], return_tensors="pt")
print(tokenizer.convert_ids_to_tokens(input_ids[0]))

['▁A', '▁person', '▁wants', '▁to', '▁start', '▁saving', '▁money', '▁so', '▁that', '▁they', '▁can', '▁afford', '▁', 'a', '▁nice', '▁vacation', '▁at', '▁the', '▁end', '▁of', '▁the', '▁year', '.', '▁After', '▁looking', '▁over', '▁their', '▁budget', '▁and', '▁expenses', ',', '▁they', '▁decide', '▁the', '▁best', '▁way', '▁to', '▁save', '▁money', '▁is', '▁to', '▁', '<unk>', 'n', '▁(', '▁', ')', '▁make', '▁more', '▁phone', '▁calls', '▁(', '▁', ')', '▁quit', '▁eating', '▁lunch', '▁out', '▁(', '▁', ')', '▁buy', '▁less', '▁with', '▁', 'monopol', 'y', '▁money', '▁(', '▁', ')', '▁have', '▁lunch', '▁with', '▁friends', '</s>']


In [5]:
QUESTION_MAX_LENGTH = 45
MAX_ANSWER_LENGTH = 40

In [6]:
def check(input_text):
    input_ids = tokenizer.encode(input_text, return_tensors='pt')[0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    original = input_ids.tolist()
    anchor = []
    for i in range(len(tokens)):
        if (i < len(tokens) - 2 and tokens[i] == '▁(' and tokens[i + 1] == '▁'
                and tokens[i + 2] == ')') or original[i] == 1:
            anchor.append(i)
    # 0 1 2 3 4
    for x in reversed(range(1, 5)):
        if anchor[x] - anchor[x - 1] < MAX_ANSWER_LENGTH:
            [
                original.insert(anchor[x], 0)
                for _ in range(MAX_ANSWER_LENGTH - (anchor[x] - anchor[x - 1]))
            ]
        else:
            raise Exception('Wrong size')
    if anchor[0] < QUESTION_MAX_LENGTH:
        [
            original.insert(anchor[0], 0)
            for _ in range(QUESTION_MAX_LENGTH - anchor[0])
        ]
    else:
        raise Exception('Wrong size')
    return original


In [7]:
def new_compute_bias(self, query_length, key_length, device=None):
    """Compute binned relative position bias"""
    if device is None:
        device = self.relative_attention_bias.weight.device
    context_position = torch.arange(query_length,
                                    dtype=torch.long,
                                    device=device)[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long,
                                   device=device)[None, :]

    relative_position = memory_position - context_position  # shape (query_length, key_length)
    # implementation='simple' 
    implementation='complicated'
    if self.is_decoder:
        pass
    elif implementation=='simple':
        # pass
        start_pos = QUESTION_MAX_LENGTH
        leng = MAX_ANSWER_LENGTH
        a = torch.arange(start_pos + leng * 0, start_pos + leng * 1, dtype=int)
        b = torch.arange(start_pos + leng * 1, start_pos + leng * 2, dtype=int)
        c = torch.arange(start_pos + leng * 2, start_pos + leng * 3, dtype=int)
        d = torch.arange(start_pos + leng * 3, start_pos + leng * 4, dtype=int)
        context_position_new = context_position.clone()
        context_position_new[b] = context_position_new[a]
        context_position_new[c] = context_position_new[a]
        context_position_new[d] = context_position_new[a]
        context_position_new[-1] = context_position_new[a[0]] + leng
        memory_position_new = context_position_new.clone().view(1, -1)
        relative_position_new = memory_position_new - context_position_new  # shape (query_length, key_length)
    if implementation=='complicated':
        mot=[a,b,c,d]
        for i,x in enumerate(mot):
            for j,y in enumerate(mot):
                if i!=j:
                    relative_position_new[x,y]=200 # no distance, a very special distance

        relative_position = relative_position_new
        
    relative_position_bucket = self._relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        bidirectional=(not self.is_decoder),
        num_buckets=self.relative_attention_num_buckets,
        max_distance=self.relative_attention_max_distance,
    )

    values = self.relative_attention_bias(
        relative_position_bucket
    )  # shape (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1]).unsqueeze(
        0)  # shape (1, num_heads, query_length, key_length)
    return values



In [8]:
MODE = 'new'  #'old'
def set_mode(MODE):
    for part in ['encoder', 'decoder']:
        for block in getattr(model, part).block:
            for layer in block.layer:
                # only need to deal in the Encoder level
                if hasattr(
                        layer, 'SelfAttention'
                ) and layer.SelfAttention.has_relative_attention_bias:
                    layer.SelfAttention.compute_bias = partial(
                        new_compute_bias if MODE == 'new' else
                        DEFAULT_COMPUTE_BIAS, layer.SelfAttention)

In [9]:
model.hf_device_map

{'': device(type='cuda', index=0)}

In [10]:
import textwrap


def measure_unalike(arr, print_arr=False):
    n = len(arr)
    arr = pd.Series(arr).value_counts()
    if print_arr:
        print(arr)
    return 1 - ((arr / n)**2).sum()


question_to_do = 5
per_question = 20


def run_tokens(tokens):
    res = model.generate(tokens, max_new_tokens=MAX_ANSWER_LENGTH)
    return tokenizer.batch_decode(res, skip_special_tokens=True)


def run_model(input_string, **generator_args):
    input_ids = tokenizer.encode(input_string, return_tensors="pt")
    # print(torch.argwhere(input_ids[0]==2)[0,0]+2)
    res = model.generate(input_ids.to(0),
                         **generator_args,
                         max_new_tokens=MAX_ANSWER_LENGTH)
    return tokenizer.batch_decode(res, skip_special_tokens=True)


# 
# run_model(dataset_test[0][0])


In [15]:
print(textwrap.fill(dataset_test[0][0]))
input=check(dataset_test[0][0])
set_mode('old')
print('old ',run_tokens(torch.tensor(input).view(1,-1).to(0)))
set_mode('new')
print('new ',run_tokens(torch.tensor(input).view(1,-1).to(0)))

A person wants to start saving money so that they can afford a nice
vacation at the end of the year. After looking over their budget and
expenses, they decide the best way to save money is to \n ( ) make
more phone calls ( ) quit eating lunch out ( ) buy less with monopoly
money ( ) have lunch with friends
old  ['quit eating lunch out']
new  ['buy less lunch calls']


In [None]:
epochs=1
for epoch in epochs:
    