In [13]:
from dataset import get_dataloader
import torch

In [39]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base",torch_dtype=torch.bfloat16)

In [40]:
model.dtype

torch.bfloat16

In [41]:
test_dataloader = get_dataloader(tokenizer=tokenizer,split="test",is_encoder_decoder=True)

Loaded 90281 examples from test split


In [42]:
for example in test_dataloader:
    example = example
    break

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [43]:
example

{'input_ids': tensor([[   27,     7,     8,   826,  2493, 10998,    42, 10747,     7,    15,
            58,    71,  2103,    19,  5871,    46,  7298,     3,  6443,   251,
            81,    46,   605,     5,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]]), 'labels': tensor([[10998,     1]]), 'loss_weight_mask': tensor([[1., 1.]])}

In [44]:
tokenizer.batch_decode(example['input_ids'],skip_special_tokens=True)

['Is the following statement True or False? A notice is commonly an announcement containing information about an event.']

In [45]:
tokenizer.batch_decode(example['labels'],skip_special_tokens=True)

['True']

In [46]:
model.to("cuda")

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo):

In [54]:
with torch.autocast(
                        "cuda" if torch.cuda.is_available() else "cpu",dtype=model.dtype
                    ):
    encoder_output = model.get_encoder()(
                                input_ids=example["input_ids"].to(model.device),
                                attention_mask=example["attention_mask"].to(model.device),
                            )

In [55]:
with torch.autocast(
                        "cuda" if torch.cuda.is_available() else "cpu",dtype=model.dtype
                    ):
    decoder_args = {
                                "attention_mask": example["attention_mask"],
                                "use_cache": False,
                                "encoder_outputs": encoder_output,
                            }

In [56]:
with torch.autocast(
                        "cuda" if torch.cuda.is_available() else "cpu",dtype=model.dtype
                    ):
    gen_inputs = model.prepare_inputs_for_generation(
                                input_ids=torch.tensor(
                                    [[tokenizer.pad_token_id]] * len(example["input_ids"])
                                ).to(example["input_ids"].device),
                                **decoder_args,
                            )

In [57]:
gen_inputs

{'decoder_input_ids': tensor([[0]]),
 'past_key_values': None,
 'encoder_outputs': BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-0.1025,  0.0762,  0.1279,  ..., -0.0052,  0.0305,  0.1602],
          [-0.0286,  0.1104,  0.1436,  ...,  0.0845, -0.0486,  0.1465],
          [-0.0154,  0.0003,  0.0105,  ...,  0.0022,  0.0018, -0.0055],
          ...,
          [-0.1196, -0.0376, -0.0500,  ...,  0.0408,  0.1543,  0.1680],
          [ 0.0054, -0.1025,  0.0104,  ...,  0.0898,  0.0996, -0.0209],
          [ 0.0084,  0.0066,  0.0146,  ...,  0.0024, -0.0024,  0.0004]]],
        device='cuda:0', dtype=torch.bfloat16, grad_fn=<MulBackward0>), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1]]),
 'head_mask': None,
 'decoder_head_mask': None,
 'decoder_attention_mask': None,
 'cross_attn_head_mask': None,
 'use_cache': False}

In [20]:
logits = model(
                            **gen_inputs,
                        ).logits


In [21]:
logits.size()

torch.Size([1, 1, 32128])

In [22]:
logits = logits[:, -1, :]
logits.size()

torch.Size([1, 32128])

In [23]:
logits = torch.nn.functional.softmax(logits, dim=-1)
logits.size()

torch.Size([1, 32128])

In [25]:
true_tokens_ids = tokenizer.encode("True", add_special_tokens=False)
false_tokens_ids = tokenizer.encode("False", add_special_tokens=False)
print(true_tokens_ids)
print(false_tokens_ids)
yes_id = true_tokens_ids[0]
no_id = false_tokens_ids[0]
print(yes_id)
print(no_id)
print(tokenizer.decode([yes_id]))
print(tokenizer.decode([no_id]))

[10998]
[10747, 7, 15]
10998
10747
True
Fal


In [26]:
logits = logits[:, [yes_id, no_id]]
logits

tensor([[0.8580, 0.1017]], grad_fn=<IndexBackward0>)

In [27]:
logits[:, 0] / (logits[:, 0] + logits[:, 1])

tensor([0.8940], grad_fn=<DivBackward0>)