In [1]:
import os

# Set before torch is imported (below) so CUDA only exposes physical GPU 1
# to this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from functools import partial
import nltk
from src.contextual_bart import ContextualisedBartModel,BartForContextualRecovery
from src.dataset_processor import load_all_data
from src.utils import SmartCollator, get_args, setuptokenizer
from src.dataset_processor import (
    ContextGenerationDataset,
)
from transformers import BartTokenizer, BartConfig,BartForConditionalGeneration
from src.model_utils import CustomTrainer, get_training_arguments
import torch
from src.config import DATASET_PATH
from transformers.trainer_callback import EarlyStoppingCallback
import pickle as pk


# Fetch the NLTK "punkt" package (a no-op when it is already up to date).
nltk.download("punkt")


def generate_tokenizer_and_data(args):
    """Load the train/dev records and wrap them in tokenised datasets.

    Args:
        args: Any object exposing a ``model_base`` attribute (the name of the
            pretrained checkpoint whose tokenizer should be used).

    Returns:
        A 3-tuple ``(train_dataset, test_dataset, [train_records, test_records])``
        where the datasets are ``ContextGenerationDataset`` instances sharing
        one tokenizer, and the list holds the raw records returned by
        ``load_all_data`` for the "train" and "dev" splits respectively.
    """

    # load the dataset
    train_data_packet = load_all_data(DATASET_PATH, mode="train")
    test_data_packet = load_all_data(DATASET_PATH, mode="dev")

    print(f"Training Data size: {len(train_data_packet)}")
    # This line used to be labelled "Training Data size" as well, which made
    # the two split sizes indistinguishable in the log.
    print(f"Test Data size: {len(test_data_packet)}")

    tokenizer = setuptokenizer(
        model_base=args.model_base,
        special_tokens=[],
    )
    # "[SEP]" is added as an extra token; callers look its id up through
    # tokenizer.get_added_vocab().
    tokenizer.add_tokens(["[SEP]"])

    def _build_dataset(records):
        """Wrap one split's records in a ContextGenerationDataset."""
        split_dataset = ContextGenerationDataset(
            tokenizer=tokenizer, nb_records=len(records),
        )
        # Data mode 1 -- its meaning is defined by ContextGenerationDataset.
        split_dataset.change_data_mode(1)
        split_dataset.set_record(records)
        return split_dataset

    train_dataset = _build_dataset(train_data_packet)
    test_dataset = _build_dataset(test_data_packet)

    return train_dataset, test_dataset, [train_data_packet, test_data_packet]



def model_init(
    vocab_size,
    context_delimiter_id,
    model_base="facebook/bart-base",
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
):
    """Create a ``BartForContextualRecovery`` model ready for use on ``device``.

    Args:
        vocab_size: Size the token-embedding matrix is resized to.
        context_delimiter_id: Token id stored on the config as
            ``context_delimiter_id``.
        model_base: Pretrained checkpoint name/path used for both the config
            and the weights.
        device: Target device. The default is computed once, when this
            function is defined (CUDA if available, else CPU).

    Returns:
        The model after embedding resize, moved to ``device``.
    """
    # Start from the pretrained config and attach the delimiter id to it.
    config = BartConfig.from_pretrained(model_base)
    config.context_delimiter_id = context_delimiter_id

    model = BartForContextualRecovery.from_pretrained(
        model_base,
        config=config,
        ignore_mismatched_sizes=True,
    )

    # Make the embedding table match the (extended) tokenizer vocabulary.
    model.resize_token_embeddings(vocab_size)  # type: ignore
    model_on_device = model.to(device)  # type: ignore
    return model_on_device

[nltk_data] Downloading package punkt to /home/nlplab/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
from dataclasses import dataclass


@dataclass
class Args:
    """Minimal argument holder; `generate_tokenizer_and_data` only reads `model_base`."""

    model_base: str


args = Args(model_base="facebook/bart-base")
train_dataset, test_dataset, [train_data_packet, test_data_packet] = generate_tokenizer_and_data(args)
# Id of the "[SEP]" token that was added on top of the base vocabulary.
context_delimiter_id = train_dataset.tokenizer.get_added_vocab()['[SEP]']

train_model_path = "trained_models_mtl/bart_base_model_1/checkpoint-45525/pytorch_model.bin"

generator = model_init(len(train_dataset.tokenizer),
                       context_delimiter_id=context_delimiter_id,
                       model_base=args.model_base,)

# Deserialize onto the CPU first: without map_location, torch.load tries to
# restore each tensor on the device it was saved from, which fails when that
# GPU is not visible (CUDA_VISIBLE_DEVICES) or no GPU is present.
# load_state_dict then copies the values into the model's own parameters,
# wherever they live.
state_dict = torch.load(train_model_path, map_location="cpu")
generator.load_state_dict(state_dict)

In [6]:
# Decoding settings consumed by the generate() call further down.
sample_too = True  # forwarded to generate() as do_sample

# Top-k / nucleus filters are only supplied when sampling is switched on.
sampling_helper = dict(top_k=30, top_p=0.95) if sample_too else {}

max_length = 250
length_penalty = 2.6
beam_size = 4
repetition_penalty = 1.56

# Number of sequences to return: 10 when sampling, otherwise one per beam.
return_top_beams = 10 if sample_too else beam_size

In [7]:
# Single-record dataset that reuses the test split's tokenizer; it is used
# below to tokenise one hand-written example.
dataset = ContextGenerationDataset(
    tokenizer=test_dataset.tokenizer,
    nb_records=1,
)
dataset.change_data_mode(1)  # same data mode as the train/test datasets

In [None]:
# Peek at the first dev record to see what one raw example looks like.
test_data_packet[0]

In [8]:
# Device the input tensors are moved to: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
from src.dataset_processor import ContextualGenerationData

In [10]:

# Hand-written inference example. The text is spelled out as adjacent string
# literals that concatenate to one line with no surrounding whitespace, so no
# newline removal or stripping is needed. Note the two spaces after "[SEP]".
data = ContextualGenerationData(
    input=(
        "Seven (stylized as Se7en) is a 1995 American crime thriller film "
        "directed by David Fincher and written by Andrew Kevin Walker. "
        "It stars Brad Pitt, Morgan Freeman, Gwyneth Paltrow, and John C. McGinley. "
        "[SEP]  Set in a crime-ridden, unnamed city, Seven's plot follows "
        "disenchanted, near-retirement detective William Somerset (Freeman) "
        "and his new partner, the recently transferred David Mills (Pitt), "
        "as they attempt to stop a serial killer before he can complete "
        "a series of murders based on the seven deadly sins."
    ),
    output="",
)

# Tokenise the example and reshape ids/mask to a batch of one on `device`.
batch = dataset.procesTexts(data)
b_input_ids = batch.input_ids.view(1, -1).to(device)
b_input_mask = batch.attention_mask.view(1, -1).to(device)

# Sampling settings for this example; these replace any earlier values of
# sample_too / sampling_helper / return_top_beams.
sample_too = True
sampling_helper = dict(top_k=25, top_p=0.95) if sample_too else {}
return_top_beams = 25 if sample_too else beam_size

In [None]:
# Display the example object that was just built.
data

In [None]:
# Inspect beam_sample's signature and docstring instead of invoking it: a bare
# call supplies no inputs to decode from (per the transformers GenerationMixin
# API it needs at least the input ids and a beam scorer), so it can only raise
# and would stop a Restart-and-Run-All.
help(generator.beam_sample)

In [12]:

# seed_everything(2982)
# NOTE(review): the saved output of this cell ends in a ValueError about the
# attention-mask batch dimension (100 vs 1). 100 matches
# beam_size * return_top_beams (4 * 25), so presumably a mask built inside
# BartForContextualRecovery is not expanded along with the beams -- confirm
# in the model code.
generator.eval()
with torch.no_grad():
    sample_outputs = generator.generate(
        input_ids=b_input_ids,
        **sampling_helper,
        # attention_mask=b_input_mask,
        num_beams=beam_size,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        early_stopping=False,
        use_cache=False,
        max_length=max_length,
        no_repeat_ngram_size=2,
        num_return_sequences=return_top_beams,
        do_sample=sample_too,
        return_dict_in_generate=False,
        eos_token_id=dataset.tokenizer.eos_token_id,
    )

# Decode every sequence generate() actually returned, rather than indexing by
# return_top_beams, so a different row count cannot raise IndexError or
# silently drop candidates.
oop = [
    dataset.tokenizer.decode(
        sequence,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    for sequence in sample_outputs
]

print(f'Article Section: {data.input} \n')
# dict.fromkeys removes duplicates while keeping the order generate() returned
# the candidates in; set() printed them in an arbitrary, run-dependent order.
for q in dict.fromkeys(oop):
    print(q)

torch.Size([100, 121])


ValueError: Attention mask should be of size (100, 1, 1, 66), but is torch.Size([1, 1, 1, 66])

In [None]:
from sentence_transformers import SentenceTransformer