# T5

- 📺 **Video:** [https://youtu.be/b6KFaT8mK4g](https://youtu.be/b6KFaT8mK4g)

## Overview
- T5 (Text-to-Text Transfer Transformer) casts every NLP task as text-to-text, which enables multi-task pre-training with a single model.
- Tasks are unified through natural-language prompts and sequence-to-sequence modeling.
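
To make the text-to-text framing concrete, here is a minimal sketch of how different tasks could be cast as (input text, target text) pairs. The helper `to_text_to_text` and its task names are hypothetical, introduced only for illustration; the prefixes are modeled on the examples in the T5 paper, and the label wording for the classification task is an assumption.

```python
# Hypothetical helper: turn a task example into an (input text, target text) pair.
def to_text_to_text(task, text, target):
    if task == "translate_en_de":
        return f"translate English to German: {text}", target
    if task == "summarize":
        return f"summarize: {text}", target
    if task == "cola":
        # Classification labels are emitted as words, not class indices.
        # Assumed convention for this sketch: 1 means acceptable, 0 means not acceptable.
        label = "acceptable" if target == 1 else "not acceptable"
        return f"cola sentence: {text}", label
    raise ValueError(f"unknown task: {task}")


examples = [
    to_text_to_text("translate_en_de", "That is good.", "Das ist gut."),
    to_text_to_text("summarize", "a long news article", "a short summary"),
    to_text_to_text("cola", "The course is jumping well.", 0),
]
for source, target in examples:
    print(f"{source!r} -> {target!r}")
```

Because every task shares this format, one encoder-decoder with one loss can be trained on all of them.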

## Key ideas
- **Text-to-text:** both inputs and outputs are text strings.
- **Span corruption:** mask spans with sentinel tokens during pre-training.
- **Multi-task mixture:** train on translation, summarization, and QA simultaneously.
- **Prompting:** a task prefix tells the model which task to perform.
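
The span-corruption objective listed above can be sketched as follows. This is a simplified illustration rather than the actual T5 preprocessing: the spans are chosen by hand instead of sampled at random, tokens are whitespace-split words, and the function name `span_corrupt` is hypothetical. The `<extra_id_N>` sentinel naming is assumed here, following the convention of the released T5 vocabulary.

```python
def span_corrupt(tokens, spans):
    """Replace each span with a sentinel token and build the matching target.

    `spans` is a sorted list of non-overlapping half-open (start, end) index pairs.
    Returns (corrupted input, target) as space-joined strings.
    """
    inputs, targets = [], []
    cursor = 0
    for i, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        inputs.extend(tokens[cursor:start])  # keep the text before the span
        inputs.append(sentinel)              # one sentinel replaces the whole span
        targets.append(sentinel)             # the target lists each sentinel...
        targets.extend(tokens[start:end])    # ...followed by the tokens it hides
        cursor = end
    inputs.extend(tokens[cursor:])
    targets.append(f"<extra_id_{len(spans)}>")  # a final sentinel closes the target
    return " ".join(inputs), " ".join(targets)


tokens = "Thank you for inviting me to your party last week .".split()
inp, tgt = span_corrupt(tokens, [(2, 4), (8, 9)])
print(inp)  # Thank you <extra_id_0> me to your party <extra_id_1> week .
print(tgt)  # <extra_id_0> for inviting <extra_id_1> last <extra_id_2>
```

The target contains only the dropped spans, so it is much shorter than the input, which keeps decoding cheap during pre-training.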

## Demo
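The demo builds a small character-level GRU encoder-decoder (a stand-in for the Transformer used by T5) that maps prefixed arithmetic prompts such as `add: 1 + 2` to textual answers, echoing the lecture ([https://youtu.be/Mrc4Bcr90VA](https://youtu.be/Mrc4Bcr90VA)).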

```python
import torch
from torch import nn

# Toy training set: the task prefix ("add:" / "subtract:") is part of the input text.
prompts = ['add: 1 + 2', 'add: 3 + 3', 'subtract: 5 - 2', 'add: 2 + 4']
answers = ['3', '6', '3', '6']

# Character-level vocabulary: all digits, every character in the data, and special tokens.
characters = set('0123456789') | set(' '.join(prompts + answers))
char_vocab = sorted(characters) + ['<pad>', '<s>', '</s>']
char_to_id = {c: i for i, c in enumerate(char_vocab)}

# +2 leaves room for the <s> and </s> markers.
max_src = max(len(p) for p in prompts) + 2
max_tgt = max(len(a) for a in answers) + 2

def encode(text, max_len):
    """Map text to token ids wrapped in <s> ... </s> and padded to max_len."""
    ids = [char_to_id['<s>']] + [char_to_id[c] for c in text] + [char_to_id['</s>']]
    ids += [char_to_id['<pad>']] * (max_len - len(ids))
    return ids

src = torch.tensor([encode(p, max_src) for p in prompts])
tgt = torch.tensor([encode(a, max_tgt) for a in answers])

# Shared embedding, GRU encoder, GRU decoder, and an output projection over the vocabulary.
embed = nn.Embedding(len(char_vocab), 32)
encoder = nn.GRU(32, 32, batch_first=True)
decoder = nn.GRU(32, 32, batch_first=True)
out = nn.Linear(32, len(char_vocab))
criterion = nn.CrossEntropyLoss(ignore_index=char_to_id['<pad>'])
params = [p for module in (embed, encoder, decoder, out) for p in module.parameters()]
optimizer = torch.optim.Adam(params, lr=5e-3)

for epoch in range(1, 201):
    enc_in = embed(src)
    _, hidden = encoder(enc_in)
    # Teacher forcing: feed tgt[:, :-1] and predict the next token tgt[:, 1:].
    dec_in = embed(tgt[:, :-1])
    outputs, _ = decoder(dec_in, hidden)
    logits = out(outputs)
    loss = criterion(logits.reshape(-1, len(char_vocab)), tgt[:, 1:].reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        print(f"epoch {epoch:3d} | loss {loss.item():.4f}")

# Greedy decoding on prompts that were not in the training set.
with torch.no_grad():
    for prompt in ['add: 4 + 1', 'subtract: 7 - 5']:
        src_vec = torch.tensor([encode(prompt, max_src)])
        enc_in = embed(src_vec)
        _, hidden = encoder(enc_in)
        dec_input = torch.tensor([[char_to_id['<s>']]])
        result = []
        hidden_state = hidden
        for _ in range(max_tgt):
            dec_emb = embed(dec_input)
            out_step, hidden_state = decoder(dec_emb, hidden_state)
            logits = out(out_step.squeeze(1))
            next_id = logits.argmax(dim=-1)
            token = next_id.item()
            if token == char_to_id['</s>']:
                break  # end-of-sequence token ends generation
            if token != char_to_id['<pad>']:
                result.append(char_vocab[token])
            dec_input = next_id.unsqueeze(1)  # feed the prediction back in
        print(f"Prompt '{prompt}' ->", ''.join(result))
```


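Output:

```
epoch  50 | loss 0.2398
epoch 100 | loss 0.0310
epoch 150 | loss 0.0064
epoch 200 | loss 0.0037
Prompt 'add: 4 + 1' -> 6
Prompt 'subtract: 7 - 5' -> 3
```

Both held-out answers are wrong (4 + 1 = 5 and 7 - 5 = 2). The model reproduces answers it saw during training, which is expected with only four training examples: the demo illustrates the prompt-to-text interface, not learned arithmetic.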


## Try it
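- Modify the demo, for example by adding training prompts with a new task prefix or by changing the hidden size, and compare the loss curves.
- Add a tiny dataset or counter-example, such as held-out prompts whose correct answers never appear in training.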
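
As a starting point for the second exercise, the sketch below generates a slightly larger arithmetic dataset in the same prompt format as the demo. The function `make_dataset` and its parameters are hypothetical, not part of the original notebook; subtraction operands are ordered so that answers stay non-negative.

```python
import random


def make_dataset(n, max_operand=9, seed=0):
    """Generate n (prompt, answer) pairs using the demo's 'add:' / 'subtract:' prefixes."""
    rng = random.Random(seed)  # local RNG so the dataset is reproducible
    prompts, answers = [], []
    for _ in range(n):
        a = rng.randint(0, max_operand)
        b = rng.randint(0, max_operand)
        if rng.random() < 0.5:
            prompts.append(f"add: {a} + {b}")
            answers.append(str(a + b))
        else:
            a, b = max(a, b), min(a, b)  # keep the difference non-negative
            prompts.append(f"subtract: {a} - {b}")
            answers.append(str(a - b))
    return prompts, answers


prompts, answers = make_dataset(200)
# Hold out the last 20 pairs to check whether the model generalizes.
train_prompts, train_answers = prompts[:180], answers[:180]
test_prompts, test_answers = prompts[180:], answers[180:]
print(train_prompts[:3], train_answers[:3])
```

Swapping `train_prompts` and `train_answers` in for the demo's `prompts` and `answers` should work without other changes, since the vocabulary and maximum lengths are derived from the data.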


## References
- [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf)
- [To Tune or Not to Tune? Adapting Pretrained Representations to Diverse Tasks](https://www.aclweb.org/anthology/W19-4302/)
- [GLUE: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding](https://arxiv.org/pdf/1804.07461.pdf)
- [What Does BERT Look At? An Analysis of BERT's Attention](https://arxiv.org/abs/1906.04341)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/pdf/1907.11692.pdf)
- [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461)
- [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/pdf/1910.10683.pdf)
- [UnifiedQA: Crossing Format Boundaries With a Single QA System](https://arxiv.org/abs/2005.00700)
- [Neural Machine Translation of Rare Words with Subword Units](https://arxiv.org/pdf/1508.07909.pdf)
- [Byte Pair Encoding is Suboptimal for Language Model Pretraining](https://arxiv.org/pdf/2004.03720.pdf)
- [Eisenstein 8.1](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 7.1](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 7.4](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 7.4.1](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 7.3](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [TnT - A Statistical Part-of-Speech Tagger](https://arxiv.org/abs/cs/0003055)
- [Enriching the Knowledge Sources Used in a Maximum Entropy Part-of-Speech Tagger](https://www.aclweb.org/anthology/W00-1308/)
- [Part-of-Speech Tagging from 97% to 100%: Is It Time for Some Linguistics?](https://link.springer.com/chapter/10.1007/978-3-642-19400-9_14)
- [Natural Language Processing with Small Feed-Forward Networks](https://www.aclweb.org/anthology/D17-1309.pdf)
- [Eisenstein 10.1-10.2](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 10.3-10.4](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 10.3.1](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Accurate Unlexicalized Parsing](https://www.aclweb.org/anthology/P03-1054/)
- [Eisenstein 10.5](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 11.1](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Finding Optimal 1-Endpoint-Crossing Trees](https://www.aclweb.org/anthology/Q13-1002/)
- [Eisenstein 11.3](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)


*Links only; we do not redistribute slides or papers.*