In [13]:
import torch
from datasets import load_dataset
from rouge import Rouge
import transformers
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import transformers
from trainer import Trainer
from torch.utils.data import DataLoader
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
import wandb
from logger import log_metrics
#import gradient checkpointing
from torch.utils.checkpoint import checkpoint_sequential
import numpy as np
from evaluate import load
from transformers import BertTokenizer, BertModel
import torch
import torch.nn.functional as F

class XSumDatasetBERT(torch.utils.data.Dataset):
    def __init__(self, model_name = 'google/pegasus-large', max_length=256, split = 'train'):
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
        self.tokenizer.max_length = max_length
        self.model = BertModel.from_pretrained('bert-base-uncased').cuda()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.dataset = load_dataset("xsum", split = split)
        self.max_length = max_length

    @torch.no_grad()
    def get_bert_embeddings(self, text):
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
        inputs = {k: v.cuda() for k, v in inputs.items()}
        embeddings = self.model(**inputs)['pooler_output']
        #compute cosine similarity between embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['document']
        text = text.split('.')
        text = [i.strip() for i in text]
        
        embeddings = self.get_bert_embeddings(text)
        first = np.random.choice(len(text), 1, replace=False)[0]
        chosen_embeddings = torch.empty((1, 768)).cuda()
        bag_of_sentences = [first]
        chosen_embeddings[0] = embeddings[first]
        current_size = len(text[first])

        while current_size < self.max_length:
            new_cosine_sim = torch.mm(chosen_embeddings, embeddings.T)
            vals, indices = torch.topk(-torch.sum(new_cosine_sim, dim = 0), k = len(text))
            for i in indices:
                if i not in bag_of_sentences:
                    chosen_embeddings = torch.cat((chosen_embeddings, embeddings[i].unsqueeze(0)), dim = 0)
                    bag_of_sentences.append(i.item())
                    current_size += len(text[i])
                    break

        bag_of_sentences = sorted(bag_of_sentences)
        final = [text[i] for i in bag_of_sentences]
        final = '. '.join(final)

        summary_text = self.dataset[idx]['summary']
        return {'article_text':final, 'summary_text': summary_text}

class XSumDatasetPowerLaw(torch.utils.data.Dataset):
    def __init__(self, model_name = 'google/pegasus-large', max_length=256, split = 'train', divisor = 2):
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
        self.tokenizer.max_length = max_length
        self.dataset = load_dataset("xsum", split = split)
        self.max_length = max_length
        self.probability = np.ones(1000) * 1000000
        for i, val in enumerate(self.probability):
            if i == 0: continue
            self.probability[i] = self.probability[i-1] / divisor
        self.indexes = np.arange(1000)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['document']
        text = text.split('.')
        
        max_idx = max(1, len(text))
        choices = np.random.choice(self.indexes[:max_idx], max_idx, replace = False, p = self.probability[:max_idx] / self.probability[:max_idx].sum())

        current_size = 0
        counter = 0
        while current_size < self.max_length and counter < max_idx:
            current_size += len(text[choices[counter]])
            counter += 1

        choices = sorted(choices[:counter])
        final = list(np.array(text)[choices])
        text = '. '.join(final)

        summary_text = self.dataset[idx]['summary']
        return {'article_text':text, 'summary_text': summary_text}




dataset1 = XSumDatasetPowerLaw()
dataset2 = XSumDatasetBERT()
dataloader = torch.utils.data.DataLoader(dataset1, batch_size=2, shuffle=True, num_workers=0)



Found cached dataset xsum (/home/da2986/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Fo

In [14]:
dataset1[0]

{'article_text': 'The full cost of damage in Newton Stewart, one of the areas worst affected, is still being assessed. \nRepair work is ongoing in Hawick and many roads in Peeblesshire remain badly affected by standing water. \nJeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she could not fault the multi-agency response once the flood hit',
 'summary_text': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.'}

In [15]:
dataset2[0]

{'article_text': '"That may not be true but it is perhaps my perspective over the last few days. Scottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs. "Obviously it is heart-breaking for people who have been forced out of their homes and the impact on businesses. uk. ',
 'summary_text': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.'}

In [29]:
for i, data in enumerate(dataloader):
    #print(data['article_text'])
    print(data['summary_text'])
    break

['The US says a purported confession from a prominent Chinese lawyer on state television runs counter to the rule of law.', 'The proportion of men taking their own lives in the UK has reached its highest level for more than a decade, according to official figures.']


In [6]:
len(sentence['article_text'])

324

In [27]:
from transformers import BertTokenizer, BertModel
import torch
import torch.nn.functional as F
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").cuda()
inputs = tokenizer(sentence, return_tensors="pt", padding=True)
inputs = {k: v.cuda() for k, v in inputs.items()}
embeddings = model(**inputs)['pooler_output']
#compute cosine similarity between embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
first = np.random.choice(len(sentence), 1, replace=False)

new_embeddings = torch.empty((1, 768)).cuda()
bag_of_sentences = [first[0]]
new_embeddings[0] = embeddings[first]
current_size = len(sentence[first[0]])

while current_size < 500:
    new_cosine_sim = torch.mm(new_embeddings, embeddings.T)
    vals, idx = torch.topk(-torch.sum(new_cosine_sim, dim = 0), k = len(sentence))
    for i in idx:
        if i not in bag_of_sentences:
            new_embeddings = torch.cat((new_embeddings, embeddings[i].unsqueeze(0)), dim = 0)
            bag_of_sentences.append(i.item())
            current_size += len(sentence[i])
            break

bag_of_sentences = sorted(bag_of_sentences)
final = [sentence[i] for i in bag_of_sentences]
final = '. '.join(final)
print(final)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Many businesses and householders were affected by flooding in Newton Stewart after the River Cree overflowed into the town. Scottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs. "Obviously it is heart-breaking for people who have been forced out of their homes and the impact on businesses. Have you been affected by flooding in Dumfries and Galloway or the Borders? Tell us about your experience of the situation and how it was handled. co


In [28]:
!nvidia-smi

Mon Dec  5 15:56:30 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P0    61W / 149W |   7798MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces