In [11]:
!pip install transformers



In [12]:
import os
import io
import requests
import numpy as np
import pandas as pd
import re
import zipfile
import random
import time
import csv
import datetime
from itertools import compress
from collections import Counter, defaultdict
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from transformers import AutoTokenizer, AutoConfig, AutoModelForPreTraining, \
                         AdamW, get_linear_schedule_with_warmup, \
                         TrainingArguments, BeamScorer, Trainer

import torch
from torch.utils.data import Dataset, random_split, DataLoader, \
                             RandomSampler, SequentialSampler

from IPython.display import clear_output

import pickle
from pathlib import Path
from typing import Optional, List, Iterable, Dict, Any,TypeVar,Union
import click
import pandas as pd
import torch
import os
import json
import numpy as np
import torch
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2PreTrainedModel
from transformers.generation_utils import top_k_top_p_filtering
import json
import logging
import math
from functools import partial
from pathlib import Path
from typing import Iterable, List
import torch.multiprocessing as mp
from transformers.pipelines import pipeline

from torch import Tensor
from torch.nn import functional as F


print(f"PyTorch version: {torch.__version__}")


PyTorch version: 1.11.0+cu113


In [13]:
T = TypeVar('T')
def ensure_dir(d):
    """Make sure directory ``d`` exists, creating it and any missing parents.

    Args:
        d: Path of the directory (``str`` or path-like).

    Raises:
        FileExistsError: If ``d`` exists but is not a directory.
    """
    # exist_ok=True closes the check-then-create race: another process (or a
    # re-run notebook cell) creating the directory in between is not an error.
    os.makedirs(d, exist_ok=True)


def batchify(data: Iterable[T], batch_size: int) -> Iterable[List[T]]:
    """Lazily split ``data`` into consecutive lists of at most ``batch_size`` items.

    Every yielded list holds exactly ``batch_size`` items except possibly the
    final one, which carries the remainder. Nothing is yielded for empty input.

    Args:
        data: Any iterable; it is consumed once, in order.
        batch_size: Maximum number of items per yielded list; must be positive.
    """
    assert batch_size > 0

    pending = []
    for element in data:
        if len(pending) < batch_size:
            pending.append(element)
            continue

        # The running batch is full: hand it off and start the next one with
        # the element that is currently in hand.
        yield pending
        pending = [element]

    # Whatever is left over forms the final (possibly short) batch
    if pending:
        yield pending


def set_seed(seed, n_gpu):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator touched here.
        n_gpu: Number of GPUs in use; the CUDA generators are seeded only when
            this is greater than zero.
    """
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)

    gpus_in_use = n_gpu > 0
    if gpus_in_use:
        torch.cuda.manual_seed_all(seed)


def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
    """Lazily yield one decoded JSON value per non-blank line of ``file``.

    Args:
        file: Path of a JSON-lines file.

    Yields:
        The Python object decoded from each line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly rather than relying on the platform default
    # encoding, which differs between operating systems.
    with open(file, encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no record and would otherwise make json.loads raise.
            if not line.strip():
                continue
            yield json.loads(line)


def load_cache(file: Path):
    """Yield previously saved JSON-lines records from ``file``, if it exists.

    A missing file is not an error: the generator simply yields nothing, so
    callers can use it unconditionally to resume earlier work.

    Args:
        file: Path of the JSON-lines cache file.

    Yields:
        The Python object decoded from each non-blank line, in file order.
    """
    if not file.exists():
        return

    # Explicit UTF-8 so reading does not depend on the platform default encoding.
    with file.open(encoding='utf-8') as f:
        for line in tqdm(f, desc=f'Loading cache from {file}'):
            # Skip blank lines: they hold no record, would make json.loads
            # raise, and must not be counted as a cached item by callers.
            if line.strip():
                yield json.loads(line)


#### Generation Classes

In [14]:
MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop


class GPT2Generation:
    """Batched, step-by-step text generation with a single GPT-2 language model.

    The end-of-text token doubles as the padding token, so sequences that have
    finished are padded out with it while the rest of the batch keeps decoding.
    """

    STOP_TOKEN = "<|endoftext|>"

    def __init__(self, model: Union[str, Path, GPT2PreTrainedModel] = 'gpt2', tokenizer: str = 'gpt2', seed: int = 42):
        """Load the model and tokenizer and seed the random number generators.

        Args:
            model: Model name or checkpoint path (loaded with
                ``GPT2LMHeadModel.from_pretrained``) or an already constructed model.
            tokenizer: Name or path passed to ``GPT2Tokenizer.from_pretrained``.
            seed: Seed forwarded to ``set_seed``.
        """
        # Set up device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
        set_seed(seed, n_gpu)

        # Set up model
        if isinstance(model, (str, Path)):
            model = GPT2LMHeadModel.from_pretrained(str(model))
        self.model = model.to(self.device)

        # Set up tokenizer
        # IMPORTANT: Note that setting the pad token like this in the constructor gives the pad_token the
        # pad_token_id = 50256, which normally belongs to the <EOS> token_id in GPT2. This is a very ugly
        # way that works at the moment of setting the pad_token_id to the <EOS> token that is already
        # included in the vocab size.
        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer, pad_token=self.STOP_TOKEN)
        assert self.tokenizer.eos_token_id == self.tokenizer.pad_token_id

    def __repr__(self):
        return f'<GPT2Generator model_name_or_path="{self.model}">'

    def __call__(self, *args, **kwargs):
        """Shorthand for :meth:`generate`."""
        return self.generate(*args, **kwargs)

    def generate(self,
                 prompt: Union[str, List[str]],
                 max_len: int = 20,
                 sample: bool = True,
                 k: int = 0,
                 p: float = 1.0,
                 temperature: float = 1.0,
                 **model_kwargs) -> List[str]:
        """Generate a continuation for each prompt.

        Args:
            prompt: A single prompt or a list of prompts (decoded as one batch).
            max_len: Maximum number of decoding steps (tokens appended per prompt).
            sample: Sample from the filtered distribution when True, otherwise
                decode greedily with ``argmax``.
            k: ``top_k`` passed to ``top_k_top_p_filtering`` when sampling.
            p: ``top_p`` passed to ``top_k_top_p_filtering`` when sampling.
            temperature: Divisor applied to the logits before sampling.
            **model_kwargs: Extra keyword arguments forwarded to the model call.

        Returns:
            One decoded continuation per prompt (prompt text excluded, special
            tokens stripped).
        """
        if isinstance(prompt, str):
            prompt = [prompt]

        # NOTE(review): `pad_to_max_length` looks like the legacy spelling of the padding option;
        # consider `padding=True` once the installed transformers version is pinned - confirm.
        encodings_dict = self.tokenizer.batch_encode_plus(prompt, pad_to_max_length=True, return_tensors='pt')

        input_ids = encodings_dict['input_ids'].to(self.device)
        attention_mask = encodings_dict['attention_mask'].to(self.device)
        batch_size, input_seq_len = input_ids.shape

        position_ids = attention_mask.cumsum(dim=1) - 1
        unfinished_sents = torch.ones(batch_size, dtype=torch.long, device=self.device)

        self.model.eval()
        with torch.no_grad():
            for step in range(max_len):
                outputs = self.model(input_ids, attention_mask=attention_mask, position_ids=position_ids,
                                     **model_kwargs)
                # Index instead of tuple-unpacking: unpacking a dict-like model output iterates its
                # keys rather than its tensors, whereas position 0 is the logits for both a plain
                # tuple and a model-output object.
                logits = outputs[0]

                # in the first decoding step, we want to use the 'real' last position for each sentence
                if step == 0:
                    last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
                    next_token_logits = logits[range(batch_size), last_non_masked_idx, :]
                else:
                    next_token_logits = logits[:, -1, :]

                if sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    if temperature != 1.0:
                        next_token_logits = next_token_logits / temperature
                    # Top-p/top-k filtering
                    next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=k, top_p=p)
                    # Sample
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding
                    next_tokens = torch.argmax(next_token_logits, dim=-1)

                # either append a padding token here if <EOS> has been seen or append next token
                tokens_to_add = next_tokens * unfinished_sents + self.tokenizer.pad_token_id * (1 - unfinished_sents)

                # this updates which sentences have not seen an EOS token so far
                # if one EOS token was seen the sentence is finished
                eos_in_sents = tokens_to_add == self.tokenizer.eos_token_id
                unfinished_sents.mul_((~eos_in_sents).long())

                # stop when there is an EOS in each sentence
                if unfinished_sents.max() == 0:
                    break

                # Update input_ids, attention_mask and position_ids
                input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
                attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=1)
                position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=1)

        # Only the newly generated part (everything after the padded prompt) is decoded
        decoded_outputs = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                           for output in input_ids[:, input_seq_len:]]
        return decoded_outputs


In [15]:
MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
class DExpertsGeneration(GPT2Generation):
    """Generation that steers a base GPT-2 model with an expert and an anti-expert model.

    At every step the next-token logits are
    ``base + alpha * (expert - antiexpert)``; a missing expert or anti-expert
    is replaced by the base model's own logits.
    """

    STOP_TOKEN = "<|endoftext|>"

    def __init__(
        self,
        base_model: Union[str, Path, GPT2PreTrainedModel],
        antiexpert_model: Union[str, Path, GPT2PreTrainedModel] = None,
        expert_model: Union[str, Path, GPT2PreTrainedModel] = None,
        tokenizer: str = 'gpt2',
        seed: int = 42,
    ):
        """Load the base, expert and anti-expert models plus the tokenizer.

        Args:
            base_model: Name, path or model instance of the base language model.
            antiexpert_model: Name, path or instance of the anti-expert; falsy to disable.
            expert_model: Name, path or instance of the expert; falsy to disable.
            tokenizer: Name or path passed to ``GPT2Tokenizer.from_pretrained``.
            seed: Seed forwarded to ``set_seed``.
        """
        # Set up device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
        set_seed(seed, n_gpu)

        self.base_model = self._load_model(base_model)
        self.antiexpert = self._load_model(antiexpert_model) if antiexpert_model else None
        self.expert = self._load_model(expert_model) if expert_model else None

        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer, pad_token=self.STOP_TOKEN)
        assert self.tokenizer.eos_token_id == self.tokenizer.pad_token_id

    def _load_model(self, model):
        """Load ``model`` from a name/path if needed and move it to the active device."""
        # The constructor annotations also allow ready-made model instances, which must not
        # be handed to from_pretrained.
        if isinstance(model, (str, Path)):
            model = GPT2LMHeadModel.from_pretrained(str(model))
        return model.to(self.device)

    def __repr__(self):
        # This class never sets `self.model` (the parent constructor is not called), so the
        # representation is built from the base model instead.
        return f'<DExpertsGenerator model_name_or_path="{self.base_model}">'

    def generate(self,
                 prompt: Union[str, List[str]],
                 max_len: int = 20,
                 sample: bool = True,
                 filter_p: float = 0.9,
                 k: int = 0,
                 p: float = 1.0,
                 temperature: float = 1.0,
                 alpha: float = 0.0,
                 **model_kwargs):
        """Generate a continuation for each prompt using the ensembled logits.

        Args:
            prompt: A single prompt or a list of prompts (decoded as one batch).
            max_len: Maximum number of decoding steps.
            sample: Sample when True, otherwise decode greedily.
            filter_p: When below 1.0, nucleus threshold applied to the *base* logits
                before the expert/anti-expert difference is added.
            k: ``top_k`` applied to the ensembled logits when sampling.
            p: ``top_p`` applied to the ensembled logits when sampling.
            temperature: Divisor applied to the ensembled logits before sampling.
            alpha: Weight of the ``expert - antiexpert`` logit difference.
            **model_kwargs: Extra keyword arguments forwarded to every model call.

        Returns:
            One decoded continuation per prompt (prompt text excluded).
        """
        if isinstance(prompt, str):
            prompt = [prompt]

        encodings_dict = self.tokenizer.batch_encode_plus(prompt, pad_to_max_length=True, return_tensors='pt')

        input_ids = encodings_dict['input_ids'].to(self.device)
        attention_mask = encodings_dict['attention_mask'].to(self.device)
        batch_size, input_seq_len = input_ids.shape

        position_ids = attention_mask.cumsum(dim=1) - 1
        unfinished_sents = torch.ones(batch_size, dtype=torch.long, device=self.device)

        # Convert the mixing weight once, up front: doing it inside the loop re-wraps an
        # existing tensor on every step after the first.
        alpha = torch.tensor(alpha).to(self.device)

        self.base_model.eval()
        if self.expert is not None:
            self.expert.eval()
        if self.antiexpert is not None:
            self.antiexpert.eval()
        with torch.no_grad():
            for step in range(max_len):
                # base model prediction (element 0 of the returned tuple is the logits)
                base_logits = self.base_model(
                    input_ids, attention_mask=attention_mask, position_ids=position_ids, return_dict=False,
                    **model_kwargs)[0]

                # expert prediction
                if self.expert is not None:
                    expert_logits = self.expert(
                        input_ids, attention_mask=attention_mask, position_ids=position_ids, return_dict=False,
                        **model_kwargs)[0]
                else:
                    expert_logits = base_logits

                # antiexpert prediction
                if self.antiexpert is not None:
                    antiexpert_logits = self.antiexpert(
                        input_ids, attention_mask=attention_mask, position_ids=position_ids, return_dict=False,
                        **model_kwargs)[0]
                else:
                    antiexpert_logits = base_logits

                if filter_p < 1.0:
                    base_logits = top_k_top_p_filtering(base_logits, top_p=filter_p)

                # DExperts
                ensemble_logits = base_logits + alpha * (expert_logits - antiexpert_logits)

                # in the first decoding step, we want to use the 'real' last position for each sentence
                if step == 0:
                    last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
                    next_token_logits = ensemble_logits[range(batch_size), last_non_masked_idx, :]
                else:
                    next_token_logits = ensemble_logits[:, -1, :]

                if sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    if temperature != 1.0:
                        next_token_logits = next_token_logits / temperature
                    if k > 0 or p < 1.0:
                        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=k, top_p=p)
                    # Sample
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding
                    next_tokens = torch.argmax(next_token_logits, dim=-1)

                # either append a padding token here if <EOS> has been seen or append next token
                tokens_to_add = next_tokens * unfinished_sents + self.tokenizer.pad_token_id * (1 - unfinished_sents)

                # this updates which sentences have not seen an EOS token so far
                # if one EOS token was seen the sentence is finished
                eos_in_sents = tokens_to_add == self.tokenizer.eos_token_id
                unfinished_sents.mul_((~eos_in_sents).long())

                # stop when there is an EOS in each sentence
                if unfinished_sents.max() == 0:
                    break

                # Update input_ids, attention_mask and position_ids
                input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
                attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=1)
                position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=1)

        # Only the newly generated part (everything after the padded prompt) is decoded
        decoded_outputs = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                           for output in input_ids[:, input_seq_len:]]
        return decoded_outputs


class AntiExpertsGeneration(GPT2Generation):
    """Generation that decodes directly from the anti-expert model's logits."""

    STOP_TOKEN = "<|endoftext|>"

    def __init__(
        self,
        base_model: Union[str, Path, GPT2PreTrainedModel],
        antiexpert_model: Union[str, Path, GPT2PreTrainedModel] = None,
        tokenizer: str = 'gpt2',
        seed: int = 42,
    ):
        """Load the anti-expert model and the tokenizer.

        Args:
            base_model: Accepted for signature parity with ``DExpertsGeneration``;
                it is not loaded or used by this class.
            antiexpert_model: Name or path of the anti-expert checkpoint. When
                falsy no model is loaded and :meth:`generate` raises.
            tokenizer: Name or path passed to ``GPT2Tokenizer.from_pretrained``.
            seed: Seed forwarded to ``set_seed``.
        """
        # Set up device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
        set_seed(seed, n_gpu)

        if antiexpert_model:
            self.antiexpert = GPT2LMHeadModel.from_pretrained(antiexpert_model).to(self.device)
        else:
            self.antiexpert = None

        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer, pad_token=self.STOP_TOKEN)
        assert self.tokenizer.eos_token_id == self.tokenizer.pad_token_id

    def __repr__(self):
        # This class never sets `self.model` (the parent constructor is not called), so the
        # representation is built from the anti-expert instead.
        return f'<AntiExpertsGenerator model_name_or_path="{self.antiexpert}">'

    def generate(self,
                 prompt: Union[str, List[str]],
                 max_len: int = 20,
                 sample: bool = True,
                 filter_p: float = 0.9,
                 k: int = 0,
                 p: float = 1.0,
                 temperature: float = 1.0,
                 alpha: float = 0.0,
                 **model_kwargs):
        """Generate a continuation for each prompt from the anti-expert alone.

        Args:
            prompt: A single prompt or a list of prompts (decoded as one batch).
            max_len: Maximum number of decoding steps.
            sample: Sample when True, otherwise decode greedily.
            filter_p: Accepted for signature parity with ``DExpertsGeneration``; unused here.
            k: ``top_k`` applied to the logits when sampling.
            p: ``top_p`` applied to the logits when sampling.
            temperature: Divisor applied to the logits before sampling.
            alpha: Accepted for signature parity with ``DExpertsGeneration``; unused here.
            **model_kwargs: Extra keyword arguments forwarded to the model call.

        Returns:
            One decoded continuation per prompt (prompt text excluded).

        Raises:
            ValueError: If no anti-expert model was configured.
        """
        # Fail with a clear message instead of hitting an unbound `antiexpert_logits` further down.
        if self.antiexpert is None:
            raise ValueError('AntiExpertsGeneration.generate requires an antiexpert model, but none was loaded')

        if isinstance(prompt, str):
            prompt = [prompt]

        encodings_dict = self.tokenizer.batch_encode_plus(prompt, pad_to_max_length=True, return_tensors='pt')

        input_ids = encodings_dict['input_ids'].to(self.device)
        attention_mask = encodings_dict['attention_mask'].to(self.device)
        batch_size, input_seq_len = input_ids.shape

        position_ids = attention_mask.cumsum(dim=1) - 1
        unfinished_sents = torch.ones(batch_size, dtype=torch.long, device=self.device)

        self.antiexpert.eval()
        with torch.no_grad():
            for step in range(max_len):
                # antiexpert prediction (element 0 of the returned tuple is the logits)
                antiexpert_logits = self.antiexpert(
                    input_ids, attention_mask=attention_mask, position_ids=position_ids, return_dict=False,
                    **model_kwargs)[0]

                # in the first decoding step, we want to use the 'real' last position for each sentence
                if step == 0:
                    last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
                    next_token_logits = antiexpert_logits[range(batch_size), last_non_masked_idx, :]
                else:
                    next_token_logits = antiexpert_logits[:, -1, :]

                if sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    if temperature != 1.0:
                        next_token_logits = next_token_logits / temperature
                    if k > 0 or p < 1.0:
                        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=k, top_p=p)
                    # Sample
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding
                    next_tokens = torch.argmax(next_token_logits, dim=-1)

                # either append a padding token here if <EOS> has been seen or append next token
                tokens_to_add = next_tokens * unfinished_sents + self.tokenizer.pad_token_id * (1 - unfinished_sents)

                # this updates which sentences have not seen an EOS token so far
                # if one EOS token was seen the sentence is finished
                eos_in_sents = tokens_to_add == self.tokenizer.eos_token_id
                unfinished_sents.mul_((~eos_in_sents).long())

                # stop when there is an EOS in each sentence
                if unfinished_sents.max() == 0:
                    break

                # Update input_ids, attention_mask and position_ids
                input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
                attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=1)
                position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=1)

        # Only the newly generated part (everything after the padded prompt) is decoded
        decoded_outputs = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                           for output in input_ids[:, input_seq_len:]]
        return decoded_outputs


In [16]:
# Download the shared Google Drive folder into the working directory. The recorded cell output
# lists checkpoint files under experts/toxicity/large/ (finetuned_gpt2_nontoxic, finetuned_gpt2_toxic),
# which are the paths the generation config further down points at.
import gdown
gdown.download_folder("https://drive.google.com/drive/folders/1MNKqTPlSc-gDqbiV9hmHc5sWRqeS1Pmi?usp=sharing", quiet=True)

['/content/experts/toxicity/large/finetuned_gpt2_nontoxic/config.json',
 '/content/experts/toxicity/large/finetuned_gpt2_nontoxic/log_history.json',
 '/content/experts/toxicity/large/finetuned_gpt2_nontoxic/merges.txt',
 '/content/experts/toxicity/large/finetuned_gpt2_nontoxic/pytorch_model.bin',
 '/content/experts/toxicity/large/finetuned_gpt2_nontoxic/special_tokens_map.json',
 '/content/experts/toxicity/large/finetuned_gpt2_nontoxic/tokenizer_config.json',
 '/content/experts/toxicity/large/finetuned_gpt2_nontoxic/training_args.bin',
 '/content/experts/toxicity/large/finetuned_gpt2_nontoxic/vocab.json',
 '/content/experts/toxicity/large/finetuned_gpt2_toxic/config.json',
 '/content/experts/toxicity/large/finetuned_gpt2_toxic/log_history.json',
 '/content/experts/toxicity/large/finetuned_gpt2_toxic/merges.txt',
 '/content/experts/toxicity/large/finetuned_gpt2_toxic/pytorch_model.bin',
 '/content/experts/toxicity/large/finetuned_gpt2_toxic/special_tokens_map.json',
 '/content/experts/t

#### Helper Functions

In [17]:
def _gpt2_helper(prompts: pd.Series,
                 max_len: int,
                 num_samples: int,
                 batch_size: int,
                 generator: GPT2Generation,
                 out_file: Path,
                 **generate_kwargs):
    """Generate ``num_samples`` continuations per prompt, caching each one in ``out_file``.

    Records already present in ``out_file`` are yielded first and the matching
    number of (repeated) prompts is skipped, so an interrupted run resumes
    where it stopped.

    Args:
        prompts: Prompts to continue.
        max_len: Maximum number of tokens generated per continuation.
        num_samples: Number of continuations per prompt.
        batch_size: Number of prompts passed to the generator at once.
        generator: Object whose ``generate(prompts, max_len, **kwargs)`` returns
            one string per prompt.
        out_file: JSON-lines cache file, appended to as generations are produced.
        **generate_kwargs: Extra keyword arguments forwarded to ``generator.generate``.

    Yields:
        Cached records first, then each newly generated continuation.
    """
    # Repeat prompts
    prompts = prompts.repeat(num_samples)

    # Resume generation
    # NOTE(review): cached records are stored as "<prompt> <generation>" (see the write below), whereas
    # fresh results are yielded without the prompt - confirm downstream consumers expect this asymmetry.
    num_cached_generations = 0
    for generation in load_cache(out_file):
        yield generation
        num_cached_generations += 1

    # Generate with prompts
    prompts = prompts[num_cached_generations:]
    for prompt_batch in tqdm(batchify(prompts, batch_size),
                             total=math.ceil(len(prompts) / batch_size),
                             desc='Generation',
                             dynamic_ncols=True,
                             postfix={'batch_size': batch_size}):
        # Generate
        batch = generator.generate(prompt_batch, max_len, **generate_kwargs)

        # Pair every generation with the prompt that produced it; always using the first prompt
        # of the batch mislabels the cached records whenever batch_size > 1.
        for prompt_text, generation in zip(prompt_batch, batch):
            with out_file.open('a', encoding='utf-8') as f:
                print(json.dumps(prompt_text + " " + generation), file=f)
            yield generation


def gpt2(prompts: pd.Series,
         max_len: int,
         num_samples: int,
         batch_size: int,
         model_name_or_path: str,
         out_file: Path,
         **generate_kwargs) -> Iterable[str]:
    """Yield generations produced by a plain ``GPT2Generation`` model.

    Builds the generator from ``model_name_or_path`` and delegates batching and
    caching to ``_gpt2_helper``; ``generate_kwargs`` are passed through untouched.
    """
    helper_args = dict(
        prompts=prompts,
        max_len=max_len,
        num_samples=num_samples,
        batch_size=batch_size,
        generator=GPT2Generation(model_name_or_path),  # Setup model
        out_file=out_file,
    )
    yield from _gpt2_helper(**helper_args, **generate_kwargs)


def dexperts(prompts: pd.Series,
             max_len: int,
             num_samples: int,
             batch_size: int,
             model_name_or_path: str,
             expert_name_or_path: str,
             antiexpert_name_or_path: str,
             out_file: Path,
             **generate_kwargs) -> Iterable[str]:
    """Yield generations produced by a ``DExpertsGeneration`` ensemble.

    The base, expert and anti-expert checkpoints are loaded from the three
    ``*_name_or_path`` arguments; batching and caching are delegated to
    ``_gpt2_helper`` and ``generate_kwargs`` are passed through untouched.
    """
    checkpoints = {
        'base_model': model_name_or_path,
        'expert_model': expert_name_or_path,
        'antiexpert_model': antiexpert_name_or_path,
    }
    ensemble = DExpertsGeneration(**checkpoints)

    yield from _gpt2_helper(prompts=prompts, max_len=max_len, num_samples=num_samples,
                            batch_size=batch_size, generator=ensemble, out_file=out_file,
                            **generate_kwargs)


def antiexperts(prompts: pd.Series,
             max_len: int,
             num_samples: int,
             batch_size: int,
             model_name_or_path: str,
             antiexpert_name_or_path: str,
             out_file: Path,
             **generate_kwargs) -> Iterable[str]:
    """Yield generations produced by an ``AntiExpertsGeneration`` model.

    The generator is constructed from the base and anti-expert paths; batching
    and caching are delegated to ``_gpt2_helper`` and ``generate_kwargs`` are
    passed through untouched.
    """
    anti_generator = AntiExpertsGeneration(base_model=model_name_or_path,
                                           antiexpert_model=antiexpert_name_or_path)

    shared_args = dict(prompts=prompts, max_len=max_len, num_samples=num_samples, batch_size=batch_size)
    yield from _gpt2_helper(generator=anti_generator, out_file=out_file, **shared_args, **generate_kwargs)

In [18]:
def top_k_top_p_filtering(
    logits: Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    The input tensor is never modified: whenever a filter applies, a new tensor
    is returned. When neither filter applies the input is returned as is.

    Args:
        logits: logits distribution; the vocabulary is the last dimension
            (any number of leading dimensions).
        top_k: if > 0, keep only the top k tokens with highest probability (top-k filtering).
        top_p: if < 1.0, keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        filter_value: value written over every removed logit.
        min_tokens_to_keep: make sure at least this many tokens are kept per example in the output.

    Returns:
        Tensor with the same shape as ``logits`` and removed entries set to ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        # masked_fill builds a new tensor, leaving the caller's logits untouched
        logits = logits.masked_fill(logits < kth_largest, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Protect only min_tokens_to_keep - 1 entries here: the right-shift below always
            # keeps one more, giving min_tokens_to_keep in total.
            sorted_indices_to_remove[..., :min_tokens_to_keep - 1] = False
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits

#### Conditional Non-Toxic Generation with DExperts and AntiExperts

In [19]:
# Where generations are written, and the fallback prompt used when none is supplied
OUTPUT_DIR = "generations/toxicity/dexperts"
prompt = "Start of the sentence"

# Which generator to run and which checkpoints it loads
MODEL_TYPE = "dexperts"
MODEL = "gpt2-large"
NON_TOXIC_MODEL = "experts/toxicity/large/finetuned_gpt2_nontoxic/"
TOXIC_MODEL = "experts/toxicity/large/finetuned_gpt2_toxic/"

def main_func(output_dir=OUTPUT_DIR, prompt=prompt, use_eos=False, model=MODEL, model_type=MODEL_TYPE, nontoxic_model=NON_TOXIC_MODEL,
         toxic_model=TOXIC_MODEL, perspective_rate_limit=25, n=5, max_tokens=20, batch_size=32, resume=False,
         alpha=2.0, filter_p=0.9, p=1.0):
    """Generate ``n`` continuations of ``prompt`` and cache them under ``output_dir``.

    Generations are appended to ``<output_dir>/generations.jsonl``. Records
    already in that file are reloaded first and count towards ``n``.

    Args:
        output_dir: Directory for the output file (created if missing).
        prompt: Prompt text to continue.
        use_eos: Currently unused; kept for interface compatibility.
        model: Name or path of the base model.
        model_type: One of ``'gpt2'``, ``'dexperts'`` or ``'antiexperts'``.
        nontoxic_model: Path of the expert checkpoint (``'dexperts'`` only).
        toxic_model: Path of the anti-expert checkpoint.
        perspective_rate_limit: Currently unused; kept for interface compatibility.
        n: Number of continuations to generate for the prompt.
        max_tokens: Maximum number of tokens per continuation.
        batch_size: Number of prompts decoded per batch.
        resume: Currently unused; an existing generations file is always resumed.
        alpha: Expert/anti-expert weight (``'dexperts'`` and ``'antiexperts'``).
        filter_p: Nucleus threshold on the base logits (``'dexperts'`` and ``'antiexperts'``).
        p: Nucleus threshold used when sampling.

    Returns:
        List of all yielded generations (cached ones first).

    Raises:
        NotImplementedError: If ``model_type`` is not one of the supported values.
    """
    # Load prompts
    prompt = pd.Series([prompt])
    print("Promts Read Successfully")
    print('Prompt:', '\n', prompt)

    # Create output files
    output_dir = Path(output_dir)
    generations_file = output_dir / 'generations.jsonl'
    ensure_dir(output_dir)

    if model_type == 'gpt2':
        generations_iter = gpt2(
            prompts=prompt,
            max_len=max_tokens,
            num_samples=n,
            p=p,
            batch_size=batch_size,
            model_name_or_path=model,
            out_file=generations_file
        )

    elif model_type == 'dexperts':
        generations_iter = dexperts(
            prompts=prompt,
            max_len=max_tokens,
            num_samples=n,
            batch_size=batch_size,
            model_name_or_path=model,
            expert_name_or_path=nontoxic_model,
            antiexpert_name_or_path=toxic_model,
            out_file=generations_file,
            filter_p=filter_p,
            p=p,
            alpha=alpha,
        )

    elif model_type == 'antiexperts':
        generations_iter = antiexperts(
            prompts=prompt,
            max_len=max_tokens,
            num_samples=n,
            batch_size=batch_size,
            model_name_or_path=model,
            antiexpert_name_or_path=toxic_model,
            out_file=generations_file,
            filter_p=filter_p,
            p=p,
            alpha=alpha,
        )
    else:
        # The dispatch is on model_type, so that is the value worth reporting
        raise NotImplementedError(f'Model type {model_type} not implemented')

    # Generate (draining the iterator is what actually runs the model)
    generations = list(generations_iter)

    torch.cuda.empty_cache()
    print('Finished generation .....')

    # Hand the collected generations back instead of discarding them
    return generations



In [20]:
class myDataset(Dataset):
    """Dataset of (title, keywords, text) records rendered as one tokenized string each.

    Every item is encoded as
    ``<bos> title <sep> kw1,kw2,... <sep> text <eos>`` using the module-level
    ``SPECIAL_TOKENS`` mapping.
    """

    def __init__(self, data, tokenizer, randomize=True, max_length=None):
        """Split ``data`` into parallel title / text / keyword lists.

        Args:
            data: Mapping whose values are indexable as ``(title, text, keywords)``.
            tokenizer: Callable used to encode each rendered string; called with
                ``truncation``, ``max_length`` and ``padding`` keyword arguments.
            randomize: When True, keywords are randomly sub-sampled and shuffled
                each time an item is fetched.
            max_length: Length every item is truncated/padded to. When None the
                module-level ``MAXLEN`` is used.
        """
        title, text, keywords = [], [], []
        for record in data.values():
            title.append(record[0])
            text.append(record[1])
            keywords.append(record[2])

        self.randomize = randomize
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.title     = title
        self.text      = text
        self.keywords  = keywords

    #---------------------------------------------#

    @staticmethod
    def join_keywords(keywords, randomize=True):
        """Join ``keywords`` with commas, optionally on a random, shuffled prefix.

        With ``randomize`` a random number (0..N) of leading keywords is kept
        and shuffled; the caller's list is not modified.
        """
        N = len(keywords)

        #random sampling and shuffle
        if randomize:
            M = random.choice(range(N+1))
            keywords = keywords[:M]
            random.shuffle(keywords)

        return ','.join(keywords)

    #---------------------------------------------#

    def __len__(self):
        return len(self.text)

    #---------------------------------------------#

    def __getitem__(self, i):
        """Return the encoded item ``i`` as tensors (``label`` duplicates ``input_ids``)."""
        keywords = self.keywords[i].copy()
        kw = self.join_keywords(keywords, self.randomize)

        # Named `sequence` rather than `input` so the builtin is not shadowed
        sequence = SPECIAL_TOKENS['bos_token'] + self.title[i] + \
                   SPECIAL_TOKENS['sep_token'] + kw + SPECIAL_TOKENS['sep_token'] + \
                   self.text[i] + SPECIAL_TOKENS['eos_token']

        # Fall back to the module-level constant only when no length was given at construction
        max_length = MAXLEN if self.max_length is None else self.max_length

        # Use the tokenizer this dataset was constructed with, not whatever global
        # `tokenizer` happens to exist in the notebook.
        encodings_dict = self.tokenizer(sequence,
                                        truncation=True,
                                        max_length=max_length,
                                        padding="max_length")

        input_ids = encodings_dict['input_ids']
        attention_mask = encodings_dict['attention_mask']

        # NOTE(review): the key is 'label' (singular); if a consumer expects 'labels', confirm
        # that the collator in use maps it correctly.
        return {'label': torch.tensor(input_ids),
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask)}

In [21]:

# Marker strings used to frame the title / keywords / text segments of a prompt
SPECIAL_TOKENS = {
    "bos_token": "<|BOS|>",
    "eos_token": "<|EOS|>",
    "unk_token": "<|UNK|>",
    "pad_token": "<|PAD|>",
    "sep_token": "<|SEP|>",
}

title = "We got a lot of grief when our photo became a meme"
keywords = ['train', 'lads', 'picture', 'funny', 'instagram']
kw = myDataset.join_keywords(keywords, randomize=False)

# Prompt layout: <bos> title <sep> comma-joined keywords <sep>
# NOTE(review): the generation classes in this file load the stock 'gpt2' tokenizer without
# registering these markers, so they are presumably encoded as ordinary text - confirm this is intended.
prompt = "".join([
    SPECIAL_TOKENS['bos_token'],
    title,
    SPECIAL_TOKENS['sep_token'],
    kw,
    SPECIAL_TOKENS['sep_token'],
])


main_func(
    output_dir=OUTPUT_DIR,
    prompt=prompt,
    use_eos=False,
    model=MODEL,
    model_type=MODEL_TYPE,
    nontoxic_model=NON_TOXIC_MODEL,
    toxic_model=TOXIC_MODEL,
    perspective_rate_limit=25,
    n=5,
    max_tokens=20,
    batch_size=1,
    resume=False,
    alpha=2.0,
    filter_p=0.9,
    p=1.0,
)

Promts Read Successfully
Prompt: 
 0    <|BOS|>We got a lot of grief when our photo be...
dtype: object


Downloading:   0%|          | 0.00/666 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.02G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Generation:   0%|          | 0/5 [00:00<?, ?it/s, batch_size=1]



Finished generation .....
