In [20]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from functools import partial
import nltk
from src.contextual_bart import ContextualisedBartModel,BartForContextualRecovery,SimplifiedBeamSearch
from src.dataset_processor import load_all_data
from src.utils import SmartCollator, get_args, setuptokenizer
from src.dataset_processor import (
    ContextGenerationDataset,
)
from transformers import BartTokenizer, BartConfig,BartForConditionalGeneration
from src.model_utils import CustomTrainer, get_training_arguments
import torch
from src.config import DATASET_PATH
from transformers.trainer_callback import EarlyStoppingCallback
import pickle as pk
import torch
from transformers import (    AutoTokenizer,
          AutoModelForSeq2SeqLM,
         LogitsProcessorList,    MinLengthLogitsProcessor, StoppingCriteriaList, MaxLengthCriteria,
         TopKLogitsWarper, TemperatureLogitsWarper,BeamSearchScorer,)

# Fetch the NLTK "punkt" tokenizer resource (needs network access the first time).
nltk.download(info_or_id="punkt")


def generate_tokenizer_and_data(args):
    """Load the train/dev splits, build the tokenizer and wrap both splits as datasets.

    Args:
        args: object exposing ``model_base`` (model identifier handed to ``setuptokenizer``).

    Returns:
        ``(train_dataset, test_dataset, [train_data_packet, test_data_packet])``: two
        ``ContextGenerationDataset`` objects sharing a single tokenizer, plus the raw record
        collections they were built from. The "test" split is the one loaded with ``mode="dev"``.
    """
    train_data_packet = load_all_data(DATASET_PATH, mode="train")
    test_data_packet = load_all_data(DATASET_PATH, mode="dev")

    print(f"Training Data size: {len(train_data_packet)}")
    # The dev split serves as the test set here, so label it as such.
    print(f"Test Data size: {len(test_data_packet)}")

    tokenizer = setuptokenizer(model_base=args.model_base, special_tokens=[])
    # "[SEP]" separates context passages in the input text; its id is looked up later
    # through tokenizer.get_added_vocab().
    tokenizer.add_tokens(["[SEP]"])

    train_dataset = _build_context_dataset(tokenizer, train_data_packet)
    test_dataset = _build_context_dataset(tokenizer, test_data_packet)

    return train_dataset, test_dataset, [train_data_packet, test_data_packet]


def _build_context_dataset(tokenizer, data_packet):
    """Wrap ``data_packet`` in a ContextGenerationDataset (data mode 1) that uses ``tokenizer``."""
    dataset = ContextGenerationDataset(tokenizer=tokenizer, nb_records=len(data_packet))
    # NOTE(review): the meaning of data mode 1 is defined in src.dataset_processor; confirm there.
    dataset.change_data_mode(1)
    dataset.set_record(data_packet)
    return dataset



def model_init(
    vocab_size,
    context_delimiter_id,
    model_base="facebook/bart-base",
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
):
    """Create a BartForContextualRecovery model from a pretrained BART checkpoint.

    Args:
        vocab_size: number of rows the token-embedding matrix is resized to (callers pass
            the length of the tokenizer after "[SEP]" was added).
        context_delimiter_id: token id stored on the config as ``context_delimiter_id``;
            presumably read by BartForContextualRecovery to locate the delimiter (confirm).
        model_base: checkpoint name/path used for both the config and the weights.
        device: where the model is moved. The default is evaluated once, when this
            function is defined: CUDA if available, otherwise CPU.

    Returns:
        The initialised model, placed on ``device``.
    """
    config = BartConfig.from_pretrained(model_base)
    config.context_delimiter_id = context_delimiter_id

    # ignore_mismatched_sizes: still load the checkpoint where some weight shapes differ
    # from this architecture (those weights are initialised fresh instead).
    model = BartForContextualRecovery.from_pretrained(
        model_base,
        config=config,
        ignore_mismatched_sizes=True,
    )

    # Grow/shrink the embedding matrix to match the tokenizer (e.g. after adding "[SEP]").
    model.resize_token_embeddings(vocab_size)  # type: ignore
    return model.to(device)  # type: ignore

ImportError: cannot import name 'SimplifiedBeamSearch' from 'src.contextual_bart' (/media/nlplab/hdd2/Laith/jojos_work/contextualised_sentence_embedding/src/contextual_bart.py)

In [2]:
from dataclasses import dataclass
@dataclass
class Args:
    """Minimal stand-in for the CLI arguments read by ``generate_tokenizer_and_data``."""

    # Hugging Face model identifier (or local path) of the base BART checkpoint.
    model_base: str
    
args = Args(model_base="facebook/bart-base")
train_dataset, test_dataset, [train_data_packet, test_data_packet] = generate_tokenizer_and_data(args)
# Id of the "[SEP]" token that generate_tokenizer_and_data added to the tokenizer.
context_delimiter_id = train_dataset.tokenizer.get_added_vocab()["[SEP]"]

train_model_path = "trained_models_mtl/bart_base_model_1/checkpoint-45525/pytorch_model.bin"

generator = model_init(
    len(train_dataset.tokenizer),
    context_delimiter_id=context_delimiter_id,
    model_base=args.model_base,
)

# map_location: materialise the checkpoint tensors on the device the model already lives on,
# so a checkpoint saved from a GPU run also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load(train_model_path, map_location=generator.device)
generator.load_state_dict(state_dict)

processing files:  ['processed_data/context_generation_train.csv']
processing files:  ['processed_data/context_generation_dev.csv']
Training Data size: 145670
Training Data size: 12151


<All keys matched successfully>

In [3]:
# One-record scratch dataset that reuses the trained tokenizer; used below to encode ad-hoc text.
dataset = ContextGenerationDataset(tokenizer=test_dataset.tokenizer, nb_records=1)
dataset.change_data_mode(1)

In [4]:
# Inference device: the first visible GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [16]:
from src.dataset_processor import ContextualGenerationData
# Demo input: two passages about the film "Seven" joined by the "[SEP]" delimiter token.
# Newlines are dropped and surrounding whitespace trimmed; no target text is needed for inference.
data = ContextualGenerationData(
    output="",
    input="""Seven (stylized as Se7en) is a 1990 American crime thriller film directed by David Fincher and written by Andrew Kevin Walker. It stars Brad Pitt, Morgan Freeman, Gwyneth Paltrow, and John C. McGinley. [SEP]  Set in a crime-ridden, unnamed city, Seven's plot follows disenchanted, near-retirement detective William Somerset (Freeman) and his new partner, the recently transferred David Mills (Pitt), as they attempt to stop a serial killer before he can complete a series of murders based on the seven deadly sins.
 
                                """.replace("\n", "").strip(),
)
batch = dataset.procesTexts(data)

In [17]:
# Prepend a batch dimension of size 1 and move both tensors to the inference device.
b_input_ids, b_input_mask = (
    tensor.view(1, -1).to(device)
    for tensor in (batch.input_ids, batch.attention_mask)
)

In [19]:
# Decode with the project's SimplifiedBeamSearch helper around the fine-tuned generator.
# NOTE(review): SimplifiedBeamSearch is no longer importable from src.contextual_bart (see the
# ImportError recorded under the import cell), so this cell fails on a fresh kernel; confirm its new home.
bb = SimplifiedBeamSearch(generator, dataset.tokenizer)
bb.generate(
    input_ids=b_input_ids,
    attention_mask=b_input_mask,
)

['Seven (stylized as Se7en) is a 1990 American crime thriller film directed by David Fincher and written by Andrew Kevin Walker. It stars Brad Pitt, Morgan Freeman, Gwyneth Paltrow, and John C. McGinley.']

In [None]:
# Let's run beam search using 3 beams.
num_beams = 3
# Every decoder sequence starts from the model's decoder start token. There is one row per
# (source example, beam) pair, matching the repeat_interleave'd encoder batch built below.
input_ids = torch.full(
    (b_input_ids.shape[0] * num_beams, 1),
    generator.config.decoder_start_token_id,
    device=generator.device,
    dtype=torch.long,
)

In [None]:
# Run the encoder once with every source row repeated num_beams times, so that decoder row
# i*num_beams + k lines up with source row i.
# no_grad: this is inference only, so do not build an autograd graph for the encoder pass.
with torch.no_grad():
    model_kwargs = {
        "encoder_outputs": generator.get_encoder()(
            b_input_ids.repeat_interleave(num_beams, dim=0),
            b_input_mask.repeat_interleave(num_beams, dim=0),
            return_dict=True,
        ),
        # Forward the same mask so the decoder's cross-attention also ignores padded
        # source positions during generation, as generate() would do.
        "attention_mask": b_input_mask.repeat_interleave(num_beams, dim=0),
    }

In [None]:
# Sanity check: expected to be (1, seq_len) after the .view(1, -1) reshaping.
b_input_ids.size()

In [None]:
# batch_size is the number of *source examples* (rows of b_input_ids, i.e. 1 here). The scorer
# expects num_beams decoder rows per example. A hard-coded 4 would disagree with the
# (1 * num_beams)-row decoder input and make beam_sample fail with a shape mismatch.
beam_scorer = BeamSearchScorer(
    batch_size=b_input_ids.shape[0],
    max_length=generator.config.max_length,
    num_beams=num_beams,
    device=generator.device,
)

In [None]:
# Score processors: MinLengthLogitsProcessor blocks the EOS token until the sequence reaches length 5.
min_length_rule = MinLengthLogitsProcessor(5, eos_token_id=generator.config.eos_token_id)
logits_processor = LogitsProcessorList([min_length_rule])

# Sampling warpers, applied in order: keep the 50 highest-scoring tokens, then rescale by temperature 0.7.
logits_warper = LogitsProcessorList(
    [
        TopKLogitsWarper(50),
        TemperatureLogitsWarper(0.7),
    ]
)
del min_length_rule

In [None]:
# beam_sample is called directly rather than through generate(), so wrap it in no_grad
# ourselves. Otherwise every decoding step would record an autograd graph and waste memory.
# NOTE(review): recent transformers releases deprecate calling beam_sample directly in favour
# of generate(); confirm against the installed version.
with torch.no_grad():
    outputs = generator.beam_sample(
        input_ids,
        beam_scorer,
        logits_processor=logits_processor,
        logits_warper=logits_warper,
        **model_kwargs,
    )


In [None]:
# Turn the generated id sequences back into text, dropping BOS/EOS/PAD and other special tokens.
dataset.tokenizer.batch_decode(sequences=outputs, skip_special_tokens=True)