In [1]:
# Imports and environment setup for evaluating the contextual BART summarisation model.
# NOTE(review): several names imported here (e.g. partial, pk, EarlyStoppingCallback and the
# transformers generation utilities) are not used in the visible cells.
import os

# Expose only GPU 2 to this process. This must take effect before torch initialises CUDA,
# which is why it sits above every import that could pull torch in.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from functools import partial
import nltk
from src.contextual_bart import ContextualisedBartModel,BartForContextualRecovery,SimplifiedBeamSearch
from src.dataset_processor import load_all_data
from src.utils import SmartCollator, get_args, setuptokenizer
from src.dataset_processor import (
    ContextGenerationDataset,
)
from transformers import BartTokenizer, BartConfig,BartForConditionalGeneration
from src.model_utils import CustomTrainer, get_training_arguments
import torch
# NOTE(review): this DATASET_PATH is rebound to a local literal further down in this cell.
from src.config import DATASET_PATH
from transformers.trainer_callback import EarlyStoppingCallback
import pickle as pk
# Duplicate of the `import torch` above; harmless.
import torch
from transformers import (    AutoTokenizer,
          AutoModelForSeq2SeqLM,
         LogitsProcessorList,    MinLengthLogitsProcessor, StoppingCriteriaList, MaxLengthCriteria,
         TopKLogitsWarper, TemperatureLogitsWarper,BeamSearchScorer,)

# Make sure the NLTK "punkt" tokenizer data is available (downloaded if missing).
nltk.download("punkt")

# Local dataset directory. This shadows the DATASET_PATH imported from src.config, and it is
# the value generate_tokenizer_and_data reads (module global, resolved at call time).
DATASET_PATH = "summarisation_data/"

def _build_split_dataset(tokenizer, records, max_len, args, **extra_kwargs):
    """Build a ContextGenerationDataset for one data split and load its records into it.

    Args:
        tokenizer: tokenizer shared by all splits (already extended with ``args.sep_token``).
        records: the data packet returned by ``load_all_data`` for this split.
        max_len: ``max_len`` passed to ``ContextGenerationDataset``.
        args: run configuration; ``sep_token`` and ``is_not_auto_encoder_data`` are read.
        **extra_kwargs: extra keyword arguments forwarded to ``ContextGenerationDataset``
            (the training split passes ``use_special_token=True``).

    Returns:
        The dataset after ``change_data_mode(1)`` and ``set_record(records)``.
    """
    dataset = ContextGenerationDataset(
        tokenizer=tokenizer,
        nb_records=len(records),
        max_len=max_len,
        context_seperator=args.sep_token,
        is_auto_encoder_data=not args.is_not_auto_encoder_data,
        **extra_kwargs,
    )
    dataset.change_data_mode(1)
    dataset.set_record(records)
    return dataset


def generate_tokenizer_and_data(args):
    """Load the train/dev/test splits, build the tokenizer and wrap each split in a dataset.

    Args:
        args: run configuration exposing ``model_base``, ``sep_token`` and
            ``is_not_auto_encoder_data``.

    Returns:
        ``(train_dataset, dev_dataset, test_dataset, [train_packet, dev_packet, test_packet])``.
        Every dataset shares the same tokenizer and is populated with its own split's records.
    """
    train_data_packet = load_all_data(DATASET_PATH, mode="train")
    dev_data_packet = load_all_data(DATASET_PATH, mode="dev")
    test_data_packet = load_all_data(DATASET_PATH, mode="test")

    print(f"Training Data size: {len(train_data_packet)}")
    print(f"Dev Data size: {len(dev_data_packet)}")
    print(f"Test Data size: {len(test_data_packet)}")

    tokenizer = setuptokenizer(
        model_base=args.model_base,
        special_tokens=[],
    )
    # The separator is added as a regular token; callers look its id up in the vocab.
    tokenizer.add_tokens([args.sep_token])

    # NOTE(review): the training split uses max_len=720 and use_special_token=True while the
    # evaluation splits use 700 and the dataset default; kept as-is, confirm this is intended.
    train_dataset = _build_split_dataset(
        tokenizer, train_data_packet, 720, args, use_special_token=True
    )
    # Each split is populated with its own records (previously the dev dataset was never
    # filled because the test dataset was re-populated in its place).
    dev_dataset = _build_split_dataset(tokenizer, dev_data_packet, 700, args)
    test_dataset = _build_split_dataset(tokenizer, test_data_packet, 700, args)

    return train_dataset, dev_dataset, test_dataset, [train_data_packet, dev_data_packet, test_data_packet]



def model_init(
    vocab_size,
    context_delimiter_id,
    model_base="facebook/bart-base",
    use_random_restriction=False,
    section_prob=(0.25, 0.45),
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
):
    """Return a zero-argument factory that builds a BartForContextualRecovery model.

    Nothing is loaded until the returned callable is invoked. When it is called, it:
      1. loads the BartConfig for ``model_base``,
      2. stores the contextual-recovery settings on that config,
      3. loads pretrained weights (``ignore_mismatched_sizes=True``),
      4. resizes the token embeddings to ``vocab_size`` and moves the model to ``device``.

    Args:
        vocab_size: tokenizer vocabulary size (including added tokens).
        context_delimiter_id: token id of the context separator, stored on the config.
        model_base: pretrained checkpoint name or path.
        use_random_restriction: stored on the config as ``use_random_restriction``.
        section_prob: stored on the config as ``section_prob``.
        device: target device. The default is evaluated once, when this function is
            defined: CUDA if available, else CPU.

    Returns:
        A callable with no arguments that returns the model placed on ``device``.
    """

    def _make_generator():
        config = BartConfig.from_pretrained(model_base)
        contextual_settings = {
            "context_delimiter_id": context_delimiter_id,
            "use_random_restriction": use_random_restriction,
            "section_prob": section_prob,
        }
        for attr_name, attr_value in contextual_settings.items():
            setattr(config, attr_name, attr_value)

        model = BartForContextualRecovery.from_pretrained(
            model_base, config=config, ignore_mismatched_sizes=True
        )
        # Grow the embedding matrix so tokens added to the tokenizer get embeddings.
        model.resize_token_embeddings(vocab_size)  # type: ignore
        return model.to(device)  # type: ignore

    return _make_generator

[nltk_data] Downloading package punkt to /home/nlplab/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
from dataclasses import dataclass
@dataclass()
class Args:
    """Minimal run configuration read by generate_tokenizer_and_data and the model cells.

    Fields:
        model_base: pretrained checkpoint name handed to the tokenizer/model loaders.
        sep_token: token added to the tokenizer and used as the context separator.
        is_not_auto_encoder_data: when True, datasets are built with
            ``is_auto_encoder_data=False`` (the flag is negated by the dataset builders).
    """

    model_base: str
    sep_token: str = "[SEP]"  # context delimiter added to the tokenizer vocabulary
    is_not_auto_encoder_data: bool = True  # negated into is_auto_encoder_data downstream
    
    
# Configure the run and load every split (plus the raw data packets) in one call.
args = Args(model_base="facebook/bart-base")
(
    train_dataset,
    dev_dataset,
    test_dataset,
    (train_data_packet, dev_data_packet, test_data_packet),
) = generate_tokenizer_and_data(args)

processing files:  ['summarisation_data/xsum_train.csv']
processing files:  ['summarisation_data/xsum_dev.csv']
processing files:  ['summarisation_data/xsum_test.csv']
Training Data size: 162548
Training Data size: 9049
The model will be trained as a non auto-encoder
The model will be trained as a non auto-encoder
The model will be trained as a non auto-encoder


In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Resolve the separator id from the configured token rather than a hard-coded "[SEP]",
# so it stays correct if Args.sep_token changes.
context_delimiter_id = train_dataset.tokenizer.get_vocab()[args.sep_token]

train_model_path = "trained_models_sum/bart_base_model_full/checkpoint-81275/pytorch_model.bin"
# Alternative checkpoint: "trained_models_mtl/bart_base_model_full/checkpoint-263195/pytorch_model.bin"

# Build the architecture (embeddings resized to the extended tokenizer), then load the fine-tuned weights.
generator = model_init(
    len(train_dataset.tokenizer),
    context_delimiter_id=context_delimiter_id,
    model_base=args.model_base,
    use_random_restriction=False,
    device=device,
)()

# Load onto CPU first. This works on CPU-only hosts even for a GPU-saved checkpoint, and it
# avoids a second GPU copy of the weights; load_state_dict copies them into the model's params.
state_dict = torch.load(train_model_path, map_location="cpu")
generator.load_state_dict(state_dict)
del state_dict
# Inference only: make sure dropout is disabled.
generator.eval();

In [4]:
# Single-record dataset used to encode the hand-written example in a later cell.
# NOTE(review): section_boundary / use_random_restrictive presumably control how the section
# point (batch.section_point) is chosen; confirm against ContextGenerationDataset.
dataset = ContextGenerationDataset(
    test_dataset.tokenizer,
    nb_records=1,
    section_boundary=(0.4, 0.48),
    context_seperator=args.sep_token,
    is_auto_encoder_data=not args.is_not_auto_encoder_data,
    use_random_restrictive=True,
)
dataset.change_data_mode(1)

The model will be trained as a non auto-encoder


In [8]:
from src.dataset_processor import ContextualGenerationData
from pytorch_lightning import seed_everything

# The example dataset was built with use_random_restrictive=True, so the section point is
# presumably sampled at random (TODO confirm in ContextGenerationDataset). Seed all RNGs so
# the encoding below is reproducible.
seed_everything(42)

example_text = """
We are helping the community work together towards the goal of advancing Machine Learning 🔥.
The Hugging Face Hub is a platform with over 60K models, 6K datasets, and 6K demos in which people can easily collaborate in their ML workflows.
The Hub works as a central place where anyone can share, explore, discover, and experiment with open-source Machine Learning.
No single company, including the Tech Titans, will be able to “solve AI” by themselves - the only way we'll achieve this is by sharing knowledge and resources in a community-centric approach. We are building the largest open-source collection of models, datasets, demos and metrics on the Hugging Face Hub to democratize and advance ML for everyone 🚀.
"""
# Collapse every whitespace run (newlines included) to a single space. Simply deleting the
# newlines would glue sentences together ("...🔥.The Hugging Face Hub...").
data = ContextualGenerationData(input=" ".join(example_text.split()), output="")
kk = 45  # index of the test example whose reference text is displayed in a later cell
batch = dataset.procesTexts(data)
b_input_ids = batch.input_ids.view(1, -1).to(device)
b_input_mask = batch.attention_mask.view(1, -1).to(device)
batch.section_point, b_input_ids.shape

(50, torch.Size([1, 158]))

In [6]:
# Reference text (the `output` field, i.e. the XSum target summary) of test example `kk`.
test_data_packet[kk].output

'Calls have been made for a room in Wrexham where heroin users can inject safely under supervision.'

In [13]:
# Reference text (the `output` field) of the first test example.
test_data_packet[0].output

'Wigan Athletic have signed former Manchester United midfielder Nick Powell on a three-year contract.'

In [18]:
from torch.utils.data import DataLoader, SequentialSampler

# In-order loader over the test split (SequentialSampler), so generated outputs come back in
# dataset order. Padding uses the shared tokenizer's pad id.
test_data_loader = DataLoader(
    test_dataset,
    batch_size=10,
    sampler=SequentialSampler(test_dataset),
    collate_fn=SmartCollator(
        max_len=700,
        pad_token_id=train_dataset.tokenizer.pad_token_id,
    ),
)

In [21]:
import tqdm

# Beam-search every test document and collect the decoded summaries in loader order.
# Loop variables get their own names so this cell does not clobber `batch`, `b_input_ids`,
# `b_input_mask` or `bb`, which the single-example cells define and read; otherwise a
# Restart & Run All would feed the last test batch to those cells.
output_summaries = []
for test_batch in tqdm.tqdm(test_data_loader, desc="generating"):
    test_input_ids = test_batch["input_ids"].to(device)
    test_attention_mask = test_batch["attention_mask"].to(device)
    test_summary_ids = generator.generate(
        input_ids=test_input_ids,
        attention_mask=test_attention_mask,
        num_beams=10,
        do_sample=False,
        num_return_sequences=1,
        max_new_tokens=320,
    )
    output_summaries.extend(
        test_dataset.tokenizer.batch_decode(
            test_summary_ids,
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
    )

# One summary per test record (num_return_sequences=1).
assert len(output_summaries) == len(test_dataset), (len(output_summaries), len(test_dataset))

torch.Size([10, 593])


In [16]:
# Summarise the hand-written example `data`. It is re-encoded here so this cell does not
# depend on b_input_ids / b_input_mask, which other cells (e.g. a batched generation loop)
# may rebind; that would silently change which input gets summarised.
# NOTE(review): if procesTexts samples the section point randomly, this encoding can differ
# from the one displayed when `data` was first encoded.
example_batch = dataset.procesTexts(data)
example_input_ids = example_batch.input_ids.view(1, -1).to(device)
example_attention_mask = example_batch.attention_mask.view(1, -1).to(device)
bb = generator.generate(
    input_ids=example_input_ids,
    attention_mask=example_attention_mask,
    num_beams=10,
    do_sample=False,
    num_return_sequences=1,
    max_new_tokens=320,
)
test_dataset.tokenizer.batch_decode(bb, clean_up_tokenization_spaces=True, skip_special_tokens=True)

['The Hugging Face Hub, an open-source machine learning hub, has been launched in New York.']