In [1]:
from transformers import AdamW, T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, get_scheduler
from datasets import load_dataset, load_metric
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import gc
from tqdm.auto import tqdm
import random
from copy import deepcopy
from torch.autograd import variable

In [2]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

In [3]:
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
device

device(type='cuda')

In [4]:
model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dro

In [5]:
torch.cuda.memory_allocated(device=device)

242026496

In [6]:
squad_raw = load_dataset('squad')
squad_raw

Reusing dataset squad (C:\Users\Pranav\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [7]:
def preprocess_squad(examples):
    return {'id': examples['id'],
            'src':
                ['question: ' + question + ' context: ' + context
                 for question, context in zip(examples['question'], examples['context'])
                ],
            'trg': [answer['text'][0] for answer in examples['answers']],
            'answers': examples['answers']
           }

In [8]:
squad = squad_raw.map(preprocess_squad, batched=True, remove_columns=squad_raw['train'].column_names)

Loading cached processed dataset at C:\Users\Pranav\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-961250d0cdb39e2a.arrow
Loading cached processed dataset at C:\Users\Pranav\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-e448ae214532babe.arrow


In [9]:
squad['validation']

Dataset({
    features: ['answers', 'id', 'src', 'trg'],
    num_rows: 10570
})

In [None]:
squad_metric = load_metric("squad")

In [None]:
def eval_squad(model):
    progress_bar = tqdm(range(len(squad['validation'])))

    model.eval()
    b_size = 32
    start = 0
    end = b_size
    exists = True
    while exists:
        examples = squad['validation'][start: end]
        input_ids = tokenizer(examples['src'], return_tensors="pt", padding=True).input_ids.to(device)
        with torch.no_grad():
            outputs = model.generate(input_ids=input_ids)

        decoded = [{'id': ide, 'prediction_text': tokenizer.decode(outputs[i], skip_special_tokens=True)} for i, ide in zip(range(len(examples['id'])), examples['id'])]
        refs = [{'id': examples['id'][i], 'answers': ans} for i, ans in enumerate(examples['answers'])]
        squad_metric.add_batch(predictions=decoded, references=refs)
        progress_bar.update(len(examples['id']))
        start = end
        end = -1 if end + b_size >= len(squad['validation']) else end + b_size
        if start == -1:
            break

    return squad_metric.compute()

In [None]:
eval_squad(model)

In [10]:
raw_paraphrases = load_dataset('paws', 'labeled_final')

Reusing dataset paws (C:\Users\Pranav\.cache\huggingface\datasets\paws\labeled_final\1.1.0\09d8fae989bb569009a8f5b879ccf2924d3e5cd55bfe2e89e6dab1c0b50ecd34)


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))




In [11]:
paraphrases = raw_paraphrases.filter(lambda example: example['label'] == 1)
paraphrases

Loading cached processed dataset at C:\Users\Pranav\.cache\huggingface\datasets\paws\labeled_final\1.1.0\09d8fae989bb569009a8f5b879ccf2924d3e5cd55bfe2e89e6dab1c0b50ecd34\cache-f8f4b05c63816056.arrow
Loading cached processed dataset at C:\Users\Pranav\.cache\huggingface\datasets\paws\labeled_final\1.1.0\09d8fae989bb569009a8f5b879ccf2924d3e5cd55bfe2e89e6dab1c0b50ecd34\cache-9bfc9aaee2aa1e33.arrow
Loading cached processed dataset at C:\Users\Pranav\.cache\huggingface\datasets\paws\labeled_final\1.1.0\09d8fae989bb569009a8f5b879ccf2924d3e5cd55bfe2e89e6dab1c0b50ecd34\cache-274702c64a43f78c.arrow


DatasetDict({
    train: Dataset({
        features: ['id', 'sentence1', 'sentence2', 'label'],
        num_rows: 21829
    })
    test: Dataset({
        features: ['id', 'sentence1', 'sentence2', 'label'],
        num_rows: 3536
    })
    validation: Dataset({
        features: ['id', 'sentence1', 'sentence2', 'label'],
        num_rows: 3539
    })
})

In [12]:
def insert_task_name(examples):
    return {
        'sentence1': ['paraphrase: ' + example for example in examples['sentence1']] + ['paraphrase: ' + example for example in examples['sentence2']],
        'sentence2': [example for example in examples['sentence2']] + [example for example in examples['sentence1']]
    }

In [13]:
paraphrases = paraphrases.map(insert_task_name, batched=True, remove_columns=paraphrases['train'].column_names)

Loading cached processed dataset at C:\Users\Pranav\.cache\huggingface\datasets\paws\labeled_final\1.1.0\09d8fae989bb569009a8f5b879ccf2924d3e5cd55bfe2e89e6dab1c0b50ecd34\cache-37d241fc255f3358.arrow
Loading cached processed dataset at C:\Users\Pranav\.cache\huggingface\datasets\paws\labeled_final\1.1.0\09d8fae989bb569009a8f5b879ccf2924d3e5cd55bfe2e89e6dab1c0b50ecd34\cache-60b05b6bb815b979.arrow
Loading cached processed dataset at C:\Users\Pranav\.cache\huggingface\datasets\paws\labeled_final\1.1.0\09d8fae989bb569009a8f5b879ccf2924d3e5cd55bfe2e89e6dab1c0b50ecd34\cache-8f7ea547eff87668.arrow


In [14]:
paraphrases['train'][0], paraphrases['train'][1]

({'sentence1': 'paraphrase: The NBA season of 1975 -- 76 was the 30th season of the National Basketball Association .',
  'sentence2': 'The 1975 -- 76 season of the National Basketball Association was the 30th season of the NBA .'},
 {'sentence1': 'paraphrase: When comparable rates of flow can be maintained , the results are high .',
  'sentence2': 'The results are high when comparable flow rates can be maintained .'})

In [15]:
train_loader = DataLoader(paraphrases['train'], shuffle=True, batch_size=8)
eval_loader = DataLoader(paraphrases['validation'], batch_size=8)

In [16]:
optimizer = AdamW(model.parameters(), lr=5e-5)

In [17]:
num_epochs = 3
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

In [18]:
%load_ext line_profiler
def train():
    progress_bar = tqdm(range(num_training_steps))

    model.train()
    for epoch in range(num_epochs):
        for batch in train_loader:
            input_ids = tokenizer(batch['sentence1'], return_tensors="pt", padding=True).input_ids.to(device)
            labels = tokenizer(batch['sentence2'], return_tensors="pt", padding=True).input_ids.to(device)
            outputs = model(input_ids=input_ids, labels=labels)
            print(outputs.loss)
            outputs.loss.backward()
            print(torch.all(torch.tensor([torch.all(p.grad == 0) for n, p, in model.named_parameters()])))
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            del input_ids
            del labels
            del outputs
            progress_bar.update(1)

In [19]:
%lprun -f train train()

HBox(children=(FloatProgress(value=0.0, max=16374.0), HTML(value='')))

tensor(2.3593, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(1.9244, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(1.6482, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(2.2818, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(2.0558, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(2.0197, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(1.8280, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(2.1202, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(2.0513, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(1.7575, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(1.6337, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(2.5335, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(1.9168, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(False)
tensor(1.3590, device='cuda:0', grad_fn=<NllLossBac

In [32]:
class EWC:
    def __init__(self, model: nn.Module, dataset: list):

        self.model = model
        self.dataset = dataset

        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        self._precision_matrices = self._diag_fisher()

        for n, p in deepcopy(self.params).items():
            self._means[n] = p.data.clone().detach().requires_grad_(True)

    def _diag_fisher(self):
        precision_matrices = {}
        for n, p in deepcopy(self.params).items():
            p.data.zero_()
            precision_matrices[n] = p.data.clone().detach().requires_grad_(True)

        for src, trg in self.dataset:
            self.model.zero_grad()
            input_ids = tokenizer(src, return_tensors="pt", padding=True).input_ids.to(device)
            labels = tokenizer(trg, return_tensors="pt", padding=True).input_ids.to(device)
            output = self.model(input_ids=input_ids, labels=labels).logits
            loss = F.nll_loss(F.log_softmax(output, dim=2).squeeze(), labels.squeeze())
            loss.backward()
            del input_ids
            del labels
            del output

            for n, p in self.model.named_parameters():
                precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset)

        precision_matrices = {n: p for n, p in precision_matrices.items()}
        return precision_matrices

    def penalty(self, model: nn.Module):
        loss = 0
        for n, p in model.named_parameters():
            _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
            print(p - self._means[n])
            loss += _loss.sum()
        return loss

In [33]:
samples = random.sample(list(zip(squad['train']['src'], squad['train']['trg'])), 20)

In [34]:
ewc = EWC(model, samples)

In [31]:
ewc._precision_matrices

{'shared.weight': tensor([[3.5317e-07, 1.2976e-06, 5.1204e-08,  ..., 1.5378e-08, 1.8123e-07,
          4.3740e-07],
         [2.7781e-07, 3.8440e-07, 1.4131e-07,  ..., 3.5741e-07, 1.2566e-08,
          3.4311e-07],
         [3.8527e-11, 6.6937e-11, 3.7453e-11,  ..., 2.1120e-11, 6.6215e-12,
          2.5619e-11],
         ...,
         [2.2584e-40, 9.8149e-40, 9.2487e-41,  ..., 2.7520e-40, 5.1848e-44,
          2.1676e-40],
         [1.6189e-40, 6.7725e-40, 6.1530e-41,  ..., 1.8451e-40, 4.7644e-44,
          1.4633e-40],
         [1.4914e-40, 6.0309e-40, 5.4145e-41,  ..., 1.6387e-40, 4.7644e-44,
          1.3058e-40]], device='cuda:0', requires_grad=True),
 'encoder.block.0.layer.0.SelfAttention.q.weight': tensor([[1.9472e-05, 4.6931e-06, 1.3818e-05,  ..., 3.2332e-06, 3.6220e-06,
          1.8747e-05],
         [1.7118e-05, 4.2751e-06, 5.5334e-06,  ..., 1.1947e-06, 5.7039e-06,
          1.1829e-05],
         [7.0245e-05, 3.0151e-06, 6.9561e-06,  ..., 6.2099e-06, 1.5882e-05,
          1.

In [39]:
for n, p in model.named_parameters():
    print(p - p.data)
    break

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<SubBackward0>)


In [35]:
ewc.penalty(model)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<SubBackward0>)
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<SubBackward0>)
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<SubBackward0>)
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0

tensor(0., device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
src, trg = squad['train']['src'][0], squad['train']['trg'][0]
src, trg

In [None]:
model.train()
input_ids = tokenizer(src, return_tensors="pt", padding=True).input_ids.to(device)
labels = tokenizer(trg, return_tensors="pt", padding=True).input_ids.to(device)
output = model(input_ids=input_ids, labels=labels)
#loss = F.nll_loss(F.log_softmax(output.logits, dim=2).squeeze(), labels.squeeze())
output.loss.backward()

In [None]:
torch.all(torch.tensor([torch.all(p.grad == 0) for n, p, in model.named_parameters()]))

In [None]:
model.eval()
sentence = "paraphrase: From the merger of the Four Rivers Council and the Audubon Council , the Shawnee Trails Council was born."

encoding = tokenizer.encode_plus(sentence,pad_to_max_length=True, return_tensors="pt")
input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)

with torch.no_grad():
    beam_outputs = model.generate(
        input_ids=input_ids, attention_mask=attention_masks,
        do_sample=True,
        max_length=256,
        top_k=120,
        top_p=0.98,
        early_stopping=True,
        num_return_sequences=10
    )
    print ("\nOriginal:")
    print (sentence)
    print ("\n")
    print ("Paraphrases:")
    final_outputs =[]
    for beam_output in beam_outputs:
        sent = tokenizer.decode(beam_output, skip_special_tokens=True,clean_up_tokenization_spaces=True)
        if sent.lower() != sentence.lower() and sent not in final_outputs:
            final_outputs.append(sent)

    for i, final_output in enumerate(final_outputs):
        print("{}: {}".format(i, final_output))

In [None]:
eval_squad(model)

In [None]:
paraphrases['train']['sentence1'][0][12:]

In [None]:
model.eval()
sentence = paraphrases['train']['sentence1'][0]

encoding = tokenizer.encode_plus(sentence,pad_to_max_length=True, return_tensors="pt")
input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)

with torch.no_grad():
    beam_outputs = model.generate(
        input_ids=input_ids, attention_mask=attention_masks,
        do_sample=True,
        max_length=256,
        top_k=120,
        top_p=0.98,
        early_stopping=True,
        num_return_sequences=10
    )
    print ("\nOriginal:")
    print (sentence)
    print ("\n")
    print ("Paraphrases:")
    final_outputs =[]
    for beam_output in beam_outputs:
        sent = tokenizer.decode(beam_output, skip_special_tokens=True,clean_up_tokenization_spaces=True)
        if sent.lower() != sentence.lower() and sent not in final_outputs:
            final_outputs.append(sent)

    for i, final_output in enumerate(final_outputs):
        print("{}: {}".format(i, final_output))

In [None]:
gc.collect()