In [1]:
# Install GPUtil into the running kernel's environment, pinned to the version this notebook was run with.
%pip install -q GPUtil==1.4.0


Collecting GPUtil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: GPUtil
  Building wheel for GPUtil (setup.py) ... [?25ldone
[?25h  Created wheel for GPUtil: filename=GPUtil-1.4.0-py3-none-any.whl size=7409 sha256=7fae8ce7ccfdf8f138d4781c76ee193828d99709af9887ac69ac102364445133
  Stored in directory: /root/.cache/pip/wheels/a9/8a/bd/81082387151853ab8b6b3ef33426e98f5cbfebc3c397a9d4d0
Successfully built GPUtil
Installing collected packages: GPUtil
Successfully installed GPUtil-1.4.0
[0m

In [2]:
import json
import math
import argparse
import torch
import logging
import numpy as np
from tqdm import tqdm
from pathlib import Path
from os import path
from transformers import AutoTokenizer, AutoModelWithLMHead
import copy
from operator import attrgetter
from typing import Dict, List, Optional, Tuple, Set, Union, Iterable
import collections
from torch import Tensor
from torch.nn import functional as F
from scipy.stats import rankdata
import gc 
from GPUtil import showUtilization as gpu_usage
from numba import cuda


In [3]:
# Utils
def tokenize_constraints(tokenizer, raw_cts):
    """
    Converts raw string constraints into token-ID literals.

    :param tokenizer: Object providing ``tokenize(str)`` and ``convert_tokens_to_ids(tokens)``.
    :param raw_cts: Nested list (sentence -> clause -> phrase string).
    :return: The same nesting, with every phrase replaced by a ``(token_ids, True)`` pair;
        every literal is marked positive.
    """
    tokenized = []
    for sentence_clauses in raw_cts:
        sentence_literals = []
        for clause in sentence_clauses:
            literals = []
            for phrase in clause:
                pieces = tokenizer.tokenize(phrase)
                literals.append((tokenizer.convert_tokens_to_ids(pieces), True))
            sentence_literals.append(literals)
        tokenized.append(sentence_literals)
    return tokenized

def report_gpu():
    """
    Prints the GPU process listing reported by torch, then frees memory.

    The listing is produced before garbage collection and before the CUDA cache is emptied,
    so it reflects the state on entry.
    """
    process_listing = torch.cuda.list_gpu_processes()
    print(process_listing)
    gc.collect()
    torch.cuda.empty_cache()
    
def free_gpu_cache():
    """
    Runs Python garbage collection and then empties torch's CUDA cache.

    Collection comes first so that tensors which have just become unreachable are dropped
    before the cache is emptied.
    """
    gc.collect()
    torch.cuda.empty_cache()

In [4]:
# lexical constraints
# Module-level logger for the constraint machinery.
logger = logging.getLogger(__name__)

# A phrase is a sequence of token (word) IDs.
Phrase = List[int]
# A literal pairs a phrase with a polarity flag. ConstrainedHypothesis treats a True flag as a
# positive literal (phrase should appear) and a False flag as a negative one (phrase to avoid).
Literal = Tuple[Phrase, bool]
# Represents a list of raw constraints for a sentence. Each constraint is a list of target-word IDs.
RawConstraintList = List[Phrase]
# Constraints for one sentence: a list of clauses, each clause being a list of literals.
ClauseConstraintList = List[List[Literal]]


class Trie:
    """
    Represents a set of phrasal constraints for an input sentence.
    These are organized into a trie.

    Each node keeps the word IDs that complete a phrase at this node (``final_ids``) and the
    outgoing arcs that continue longer phrases (``children``). ``parent_arc`` / ``parent_trie``
    link back to the parent so that the path to a node can be reconstructed.

    Note: because ``__len__`` is defined, a node holding no phrases is falsy. Always test
    nodes with ``is None`` rather than by truthiness.

    :param raw_phrases: Optional list of phrases (lists of word IDs) to insert.
    :param parent_arc: The word ID on the arc leading to this node (None for the root).
    :param parent_trie: The parent node (None for the root).
    """
    def __init__(self,
                 raw_phrases: Optional['RawConstraintList'] = None,
                 parent_arc: Optional[int] = None,
                 parent_trie: Optional['Trie'] = None) -> None:
        self.final_ids = set()  # type: Set[int]
        self.children = {}  # type: Dict[int,'Trie']
        self.parent_arc = parent_arc
        self.parent_trie = parent_trie

        if raw_phrases:
            for phrase in raw_phrases:
                self.add_phrase(phrase)

    def add_phrase(self,
                   phrase: List[int]) -> None:
        """
        Adds a phrase below this trie node, creating intermediate nodes as needed.

        :param phrase: A non-empty list of word IDs to add to this trie node.
        :raises IndexError: If the phrase is empty.
        """
        if len(phrase) == 0:
            raise IndexError('cannot add an empty phrase to a Trie')

        node = self
        for word_id in phrase[:-1]:
            child = node.children.get(word_id)
            if child is None:
                child = Trie(parent_arc=word_id, parent_trie=node)
                node.children[word_id] = child
            node = child
        # The last word is stored as a final id on the node reached by the preceding words.
        node.final_ids.add(phrase[-1])

    def delete_phrase(self,
                      phrase: List[int]) -> None:
        """
        Recursively deletes a phrase from this trie node and prunes emptied sub-tries.

        :param phrase: A non-empty list of word IDs to delete in this trie node.
        :raises IndexError: If the phrase is empty.
        :raises AssertionError: If the phrase is not stored below this node.
        """
        if len(phrase) == 0:
            raise IndexError('cannot delete an empty phrase from a Trie')

        if len(phrase) == 1:
            assert phrase[0] in self.final_ids, f"Trie {str(self)} \nDo not contain {phrase}"
            self.final_ids.remove(phrase[0])
        else:
            next_word = phrase[0]
            assert next_word in self.children.keys(), f"Trie {str(self)} \nDo not contain {phrase}"
            self.children[next_word].delete_phrase(phrase[1:])

        # Drop every arc whose sub-trie no longer holds any phrase.
        for arc in list(self.children):
            if len(self.children[arc]) == 0:
                self.children.pop(arc)

    def check_phrase(self,
                     phrase: List[int]) -> bool:
        """
        Check whether a phrase is in this trie.

        :param phrase: A list of word IDs to check existence.
        :return: True if the exact phrase is stored; False otherwise (including for an empty phrase).
        """
        if len(phrase) == 0:
            return False

        node = self
        for word_id in phrase[:-1]:
            node = node.children.get(word_id)
            if node is None:
                return False
        return phrase[-1] in node.final_ids

    def trace_phrase(self,
                     word_id: int) -> List[int]:
        """
        Reconstructs the full phrase that ends at this node with the given word.

        :param word_id: The last word ID in the phrase; must be a final id of this node.
        :return: The arcs from the root to this node followed by ``word_id``.
        """
        assert word_id in self.final_ids, f"{word_id} does not in trie node {self.final_ids}"
        phrase = self.trace_arcs()
        phrase.append(word_id)
        return phrase

    def trace_arcs(self,) -> List[int]:
        """
        Walks the parent links to collect the arcs leading from the root to this node.

        :return: The arc word IDs in root-to-node order (empty for the root).
        """
        arcs = []
        parent_trie, parent_arc = self.parent_trie, self.parent_arc
        while parent_trie is not None:
            arcs.append(parent_arc)
            parent_arc = parent_trie.parent_arc
            parent_trie = parent_trie.parent_trie
        arcs.reverse()
        return arcs

    def __str__(self) -> str:
        s = f'({list(self.final_ids)}'
        for child_id in self.children.keys():
            s += f' -> {child_id} {self.children[child_id]}'
        s += ')'
        return s

    def __len__(self) -> int:
        """
        Returns the number of phrases represented in the trie.
        """
        phrase_count = len(self.final_ids)
        for child in self.children.values():
            phrase_count += len(child)
        return phrase_count

    def step(self, word_id: int) -> Optional['Trie']:
        """
        Returns the child node along the requested arc.

        :param word_id: requested arc.
        :return: The child node along the requested arc, or None if no such arc exists.
        """
        return self.children.get(word_id, None)

    def descend(self,
              arcs: List[int]) -> Optional['Trie']:
        """
        Follows a sequence of arcs starting at this node.

        :param arcs: The arc word IDs to follow, in order.
        :return: The node reached, or None as soon as an arc is missing.
        """
        pointer = self
        for arc in arcs:
            if pointer is None:
                break
            pointer = pointer.step(word_id=arc)
        return pointer

    def final(self) -> Set[int]:
        """
        Returns the set of final ids at this node.

        :return: The set of word IDs that end a constraint at this state.
        """
        return self.final_ids


class NegativeState:
    """
    Represents the state of a hypothesis in the AvoidTrie.

    ``state`` lists the trie nodes the hypothesis is currently "inside", i.e. the nodes reached
    by following the most recently consumed words from the root. Instances are treated as
    immutable: ``consume`` returns a new object whenever the state changes.

    :param avoid_trie: The trie containing the phrases to avoid.
    :param state: The current state (defaults to root).
    """
    def __init__(self,
                 avoid_trie: 'Trie',
                 state: Optional[List['Trie']] = None) -> None:

        self.root = avoid_trie
        self.state = state if state else [self.root]

    def consume(self, word_id: int) -> 'NegativeState':
        """
        Consumes a word, and updates the state based on it. Returns new objects on a state change.

        Every currently active node, plus the root (a phrase may start at any position), is
        asked for an arc labelled ``word_id``:
        (1) If at least one arc exists, the new state is the list of nodes those arcs lead to.
        (2) If none exists and we are already sitting only at the root, this object is reused.
        (3) Otherwise a fresh object positioned at the root is returned.

        :param word_id: The word that was just generated.
        :return: The state after consuming ``word_id``.
        """
        new_state = []
        # dict.fromkeys de-duplicates like set() but keeps insertion order, so the resulting
        # state list is reproducible from run to run.
        for state in dict.fromkeys(self.state + [self.root]):
            if word_id in state.children:
                new_state.append(state.step(word_id))

        if new_state:
            return NegativeState(self.root, new_state)
        if len(self.state) == 1 and self.state[0] is self.root:
            return self
        return NegativeState(self.root, [self.root])

    def avoid(self) -> Set[int]:
        """
        Returns a set of word IDs that should be avoided. This includes the set of final states from the
        root node, which are single tokens that must never be generated.

        :return: A set of integers representing words that must not be generated next by this hypothesis.
        """
        return self.root.final().union(*[state.final() for state in self.state])

    def __str__(self) -> str:
        # str() of a list would show each node's default repr (a memory address); render the
        # nodes through their own __str__ instead.
        return '[' + ', '.join(str(state) for state in self.state) + ']'


class NegativeBatch:
    """
    Represents a set of phrasal constraints for all items in the batch.
    For each hypotheses, there is an AvoidTrie tracking its state.

    :param beam_size: The beam size.
    :param avoid_list: The list of lists (raw phrasal constraints as IDs, one for each item in the batch).
    """
    def __init__(self,
                 beam_size: int,
                 avoid_list: Optional[List['RawConstraintList']] = None) -> None:

        self.avoid_states = []  # type: List[NegativeState]

        # Store the sentence-level tries for each item in their portions of the beam.
        # The same NegativeState object is shared by all beam slots of a sentence; this is safe
        # because consume() returns a new object instead of mutating on a state change.
        if avoid_list is not None:
            for literal_phrases in avoid_list:
                self.avoid_states += [NegativeState(Trie(literal_phrases))] * beam_size

    def reorder(self, indices: torch.Tensor) -> None:
        """
        Reorders the avoid list according to the selected row indices.
        This can produce duplicates, but this is fixed if state changes occur in consume().

        :param indices: A 1-D integer tensor (any device) with the indices of hypotheses to select.
        """
        if self.avoid_states:
            # .cpu() first: Tensor.numpy()/host access is not available for CUDA tensors.
            self.avoid_states = [self.avoid_states[x] for x in indices.cpu().tolist()]

    def consume(self, word_ids: torch.Tensor) -> None:
        """
        Consumes a word for each trie, updating respective states.

        :param word_ids: A 1-D integer tensor (any device) holding one word ID per hypothesis.
        """
        if not self.avoid_states:
            return
        for i, word_id in enumerate(word_ids.cpu().tolist()):
            self.avoid_states[i] = self.avoid_states[i].consume(word_id)

    def avoid(self) -> Tuple[Tuple[int], Tuple[int]]:
        """
        Assembles a list of per-hypothesis words to avoid. The indices are (x, y) pairs into the scores
        array, which has dimensions (beam_size, target_vocab_size). These values are then used by the caller
        to mask these items so they won't be selected.

        :return: Two parallel tuples of indices: the x coordinates and y coordinates. When nothing
            has to be avoided, an empty tuple ``()`` is returned instead of a pair.
        """
        to_avoid = set()  # type: Set[Tuple[int, int]]
        for i, state in enumerate(self.avoid_states):
            for word_id in state.avoid():
                to_avoid.add((i, word_id))

        return tuple(zip(*to_avoid))  # type: ignore


class PositiveState:
    """
    Represents a set of words and phrases that must appear in the output.

    ``state`` lists the trie nodes reached by the most recently generated words, i.e. the
    phrases that are currently partially generated. ``met_phrases`` holds the phrases that were
    completed by the step that produced this object (None if there were none).

    :param positive_trie: The trie containing the phrases to appear.
    :param state: The current state (defaults to root).
    :param met_phrases: Phrases completed by the last step, if any.
    """
    def __init__(self,
                 positive_trie: 'Trie',
                 state: Optional[List['Trie']] = None,
                 met_phrases: Optional['RawConstraintList'] = None) -> None:

        self.root = positive_trie
        self.state = state if state else [self.root]
        self.met_phrases = met_phrases

    def __str__(self):
        s = f'Root: {self.root}\nState: ['
        for state in self.state:
            s += f'{state}, '
        s += f']\nMet_phrases: {self.met_phrases}'
        return s

    def allowed(self) -> Set[int]:
        """
        Returns the set of constrained words that could follow the words generated so far.

        This is the union, over the root and every active node, of the words that would
        complete a phrase there and of the words that continue a longer phrase.

        :return: The set of word IDs that make progress on some constraint (possibly empty).
        """
        allow = self.root.final().union(*[state.final() for state in self.state])
        allow |= set(self.root.children.keys()).union(*[set(state.children.keys()) for state in self.state])
        return allow

    def advance(self, word_id: int) -> 'PositiveState':
        """
        Updates the constraints object based on advancing on word_id.
        There is a complication, in that we may have started but not
        yet completed a multi-word constraint.  If the next word does not
        continue any partially generated phrase, we fall back to the root.

        :param word_id: The word ID to advance on.
        :return: A new PositiveState advanced on word_id; ``self`` is returned unchanged when
            we are at the root, nothing advanced and no phrase was completed.
        """
        new_state, met_phrases = [], []
        # dict.fromkeys de-duplicates like set() but keeps insertion order, so the order of
        # `state` and `met_phrases` (and hence printed output) is reproducible between runs.
        for state in dict.fromkeys(self.state + [self.root]):
            if word_id in state.children:
                new_state.append(state.step(word_id))
            if word_id in state.final_ids:
                met_phrases.append(state.trace_phrase(word_id))

        completed = met_phrases if met_phrases else None
        if new_state:
            return PositiveState(self.root, new_state, completed)
        if len(self.state) == 1 and self.state[0] is self.root and not met_phrases:
            return self
        return PositiveState(self.root, [self.root], completed)


class Clause:
    """
    Object used to hold clause.

    :param idx: The id of this clause.
    :param positive: The positive constraints in this clause.
    :param negative: The soft negative constraints in this clause.
    :param satisfy: Whether this clause is currently satisfied (a boolean flag).
    """

    __slots__ = ('idx', 'positive', 'negative', 'satisfy')

    def __init__(self,
                 idx: int,
                 positive: List['Phrase'],
                 negative: List['Phrase'],
                 satisfy: bool) -> None:
        self.idx = idx
        self.positive = positive
        self.negative = negative
        self.satisfy = satisfy

    def __str__(self):
        return f'clause(id={self.idx}, positive={self.positive}, negative={self.negative}, satisfy={self.satisfy})'


def is_prefix(pref: List[int],
              phrase: List[int]):
    """
    Tells whether ``pref`` is a non-empty leading part of ``phrase``.

    :param pref: Candidate prefix (word IDs).
    :param phrase: Phrase to compare against (word IDs).
    :return: True if ``pref`` is non-empty and ``phrase`` starts with it; False otherwise.
    """
    if pref:
        head = phrase[:len(pref)]
        return pref == head
    return False


class ConstrainedHypothesis:
    """
    Keep track of positive and negative constraint

    hard negative constraint will not be generated in any cases
    soft negative constraint might be generated in some case due to OR gate in clause
    positive constraints will be encourage to appear

    Attributes set by ``advance``:
    ``orders`` - ids of satisfied clauses, in the order they were satisfied;
    ``in_process`` - sorted ids of clauses with a partially generated phrase (None before the
    first update);
    ``max_process`` - largest generated fraction among the phrases currently in progress.

    :param constraint_list: A list of clause constraints (each represented as a list of literals).
    :param eos_id: The EOS word ID, or a list of IDs that all count as EOS.
    """
    def __init__(self,
                 constraint_list: 'ClauseConstraintList',
                 eos_id: Union[int, list]
                 ) -> None:
        self.eos_id = eos_id if isinstance(eos_id, list) else [eos_id]
        self.clauses = []  # type: List[Clause]

        hard_neg_pool, soft_neg_pool, pos_pool = [], [], []  # type: RawConstraintList
        for idx, clause in enumerate(constraint_list):
            if not clause:
                continue
            pos_phrases = [l[0] for l in clause if l[1]]
            neg_phrases = [l[0] for l in clause if not l[1]]
            # clause contains single negative literal: the phrase may never be generated.
            # Such clauses are not tracked in self.clauses.
            if not pos_phrases and len(neg_phrases) == 1:
                hard_neg_pool.extend(neg_phrases)
            # clause contains multiple negative literals or both negative and positive literals
            elif neg_phrases:
                soft_neg_pool.extend(neg_phrases)
                self.clauses.append(Clause(idx=idx, positive=pos_phrases, negative=neg_phrases, satisfy=True))
            # clause contains only positive literals
            elif pos_phrases:
                pos_pool.extend(pos_phrases)
                self.clauses.append(Clause(idx=idx, positive=pos_phrases, negative=[], satisfy=False))
            else:
                # A non-empty clause always has a positive or a negative literal.
                raise ValueError(f'Invalid state {clause}, should not be reached')

        self.hard_negative_state = NegativeState(Trie(hard_neg_pool)) if hard_neg_pool else None
        self.soft_negative_state = NegativeState(Trie(soft_neg_pool)) if soft_neg_pool else None
        self.positive_state = PositiveState(Trie(pos_pool)) if pos_pool else None

        self.orders = []
        self.in_process = None
        self.max_process = 0

    def __len__(self) -> int:
        """
        :return: The number of constraints.
        """
        return len(self.clauses)

    def __str__(self) -> str:
        return '\n'.join([str(c) for c in self.clauses])

    def size(self) -> int:
        """
        :return: the number of constraints
        """
        return len(self.clauses)

    def num_met(self) -> int:
        """
        :return: the number of constraints that have been met.
        """
        return sum(int(c.satisfy) for c in self.clauses)

    def met_order(self) -> tuple:
        """
        :return: the ids of the satisfied clauses, sorted ascending.
        """
        return tuple(sorted(self.orders))

    def clause_in_process(self) -> tuple:
        """
        :return: the index of clause that's in generation (an empty tuple if none was recorded yet).
        """
        # in_process stays None until advance() has updated a positive state.
        if self.in_process is None:
            return ()
        return tuple(self.in_process)

    def num_needed(self) -> int:
        """
        :return: the number of un-met constraints.
        """
        return self.size() - self.num_met()

    def finished(self) -> bool:
        """
        Return true if all the constraints have been met.

        :return: True if all the constraints are met.
        """
        return self.num_needed() == 0

    def is_valid(self, wordid: int) -> bool:
        """
        Ensures </s> is only generated when the hypothesis is completed.

        :param wordid: The wordid to validate.
        :return: True if all constraints are already met or the word ID is not the EOS id.
        """
        return self.finished() or wordid not in self.eos_id

    def avoid(self) -> Set[int]:
        """
        :return: The word IDs banned at the next step by the hard negative constraints
            (an empty set if there are none).
        """
        banned = self.hard_negative_state.avoid() if self.hard_negative_state is not None else set()
        return banned

    def eos(self) -> list:
        """
        :return: Return EOS id.
        """
        return self.eos_id

    def advance(self, word_id: int) -> 'ConstrainedHypothesis':
        """
        Updates the constraints object based on advancing on word_id.
        If one of literals in a clause is satisfied, we mark this clause as satisfied

        :param word_id: The word ID to advance on.
        :return: A deep copy of this hypothesis, advanced on ``word_id``; ``self`` is not modified.
        :raises NotImplementedError: If the hypothesis carries soft negative constraints.
        """
        # Check before copying: there is no point in deep-copying a hypothesis we cannot advance.
        if self.soft_negative_state is not None:
            raise NotImplementedError

        obj = copy.deepcopy(self)

        if obj.hard_negative_state is not None:
            obj.hard_negative_state = obj.hard_negative_state.consume(word_id)

        if obj.positive_state is not None:
            temp_pos_state = obj.positive_state.advance(word_id)
            if temp_pos_state.met_phrases is not None:
                # get newly satisfied positive literals
                phrases_to_delete = []
                newly_met_clause = set()
                for phrase in temp_pos_state.met_phrases:
                    for clause in obj.clauses:
                        if not clause.satisfy and phrase in clause.positive:
                            # NOTE(review): all literals of the satisfied clause are removed from
                            # the trie, including ones shared with still-unsatisfied clauses -
                            # confirm this is intended.
                            phrases_to_delete.extend(clause.positive)
                            clause.satisfy = True
                            assert clause.idx not in obj.orders, 'clause has already satisfied, impossible state'
                            newly_met_clause.add(clause.idx)
                obj.orders.extend(sorted(newly_met_clause))

                # delete newly satisfied literals from positive trie state
                new_root = copy.deepcopy(temp_pos_state.root)
                phrases_to_delete = [list(i) for i in set(map(tuple, phrases_to_delete))]
                for phrase in phrases_to_delete:
                    if new_root.check_phrase(phrase):
                        new_root.delete_phrase(phrase)
                # Re-locate every active node inside the pruned copy; nodes whose path was
                # pruned away are dropped.
                new_trie_states = set()
                for state in temp_pos_state.state:
                    # pointer at root state
                    if state.parent_trie is None:
                        new_trie_states.add(new_root)
                    else:
                        trace = state.trace_arcs()
                        new_state = new_root.descend(trace)
                        if new_state is not None:
                            new_trie_states.add(new_state)
                obj.positive_state = PositiveState(positive_trie=new_root, state=list(new_trie_states))
            else:
                obj.positive_state = temp_pos_state

            # Work out which unsatisfied clauses have a phrase that is partially generated.
            history = [s.trace_arcs() for s in obj.positive_state.state]
            newly_in_process = set()
            max_process = 0
            for phrase in history:
                for clause in obj.clauses:
                    phrase_in_process = [c for c in clause.positive if is_prefix(phrase, c)]
                    if not clause.satisfy and bool(phrase_in_process):
                        # fraction generated of the shortest phrase this prefix can complete
                        process_portion = len(phrase) / min([len(x) for x in phrase_in_process])
                        max_process = max(max_process, process_portion)
                        assert clause.idx not in obj.orders, 'clause has already satisfied, impossible state'
                        newly_in_process.add(clause.idx)
            obj.in_process = sorted(newly_in_process)
            obj.max_process = max_process
        return obj


def init_batch(raw_constraints: List[ClauseConstraintList],
               beam_size: int,
               eos_id: Union[int, list]) -> List[Optional[ConstrainedHypothesis]]:
    """
    Builds the per-beam-slot constraint trackers for a batch.

    :param raw_constraints: The list of clause constraints, one entry per sentence.
    :param beam_size: The beam size.
    :param eos_id: The target-language vocabulary ID of the EOS symbol.
    :return: A list of ConstrainedHypothesis objects (shape: (batch_size * beam_size,)); the
        slots of one sentence are contiguous and each slot holds its own deep copy.
    """
    constraints_list = []  # type: List[Optional[ConstrainedHypothesis]]
    for raw_list in raw_constraints:
        template = ConstrainedHypothesis(raw_list, eos_id)
        constraints_list.extend(copy.deepcopy(template) for _ in range(beam_size))
    return constraints_list


class ConstrainedCandidate:
    """
    Object used to hold candidates for the beam in topk().

    Two candidates are equal (and hash alike) when they address the same (row, col) cell,
    regardless of score, hypothesis or rank.

    :param row: The row in the scores matrix.
    :param col: The column (word ID) in the scores matrix.
    :param score: the associated accumulated score.
    :param hypothesis: The ConstrainedHypothesis containing information about met constraints.
    :param rank: Optional rank attached to the candidate.
    """

    __slots__ = ('row', 'col', 'score', 'hypothesis', 'rank')

    def __init__(self,
                 row: int,
                 col: int,
                 score: float,
                 hypothesis: 'ConstrainedHypothesis',
                 rank: Optional[float] = None,) -> None:
        self.row = row
        self.col = col
        self.score = score
        self.hypothesis = hypothesis
        self.rank = rank

    def __hash__(self):
        return hash((self.row, self.col))

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types instead of failing
        # with AttributeError on objects that have no row/col.
        if not isinstance(other, ConstrainedCandidate):
            return NotImplemented
        return self.row == other.row and self.col == other.col

    def __str__(self):
        return '({}, {}, {}, {})'.format(self.row, self.col, self.score, self.hypothesis.num_met())


# Demo / smoke test: builds a batch of hypotheses and steps one of them through a word sequence.
if __name__ == '__main__':
    # One entry per sentence; each sentence is a list of clauses and each clause a list of
    # (phrase token IDs, polarity) literals. Every literal here is positive.
    clauses = [[[([3, 4, 5], True), ([3, 4], True), ([4, 5], True)], [([3, 4], True), ([6], True), ([7], True)]],
               [[([6], True), ([6, 7], True), ([6, 7, 8], True)], [([6, 9], True), ([6, 4, 9], True)]],
               [[([3, 4, 5], True)], [([3, 4], True)], [([4, 5], True)]],
               [[([3, 4], True)], [([2, 3, 5], True)], [([6, 5], True)]]]

    constraints = init_batch(raw_constraints=clauses,
                             beam_size=1,
                             eos_id=0)

    # With beam_size=1 there is one hypothesis per sentence, so index 2 is the third sentence.
    constraint = constraints[2]
    print(constraint)
    # Printing the list shows default object reprs: ConstrainedHypothesis defines __str__ only.
    print(constraints)
    print()
    # Feed the words one at a time; advance() returns a new hypothesis each step.
    for w in [2, 3, 4, 5]:
        constraint = constraint.advance(w)
        print(constraint)
        print(constraint.positive_state)
        print(constraint.positive_state.allowed())
        print(constraint.met_order())
        print(constraint.clause_in_process())
        print()


clause(id=0, positive=[[3, 4, 5]], negative=[], satisfy=False)
clause(id=1, positive=[[3, 4]], negative=[], satisfy=False)
clause(id=2, positive=[[4, 5]], negative=[], satisfy=False)
[<__main__.ConstrainedHypothesis object at 0x795235ea0640>, <__main__.ConstrainedHypothesis object at 0x795235ea9db0>, <__main__.ConstrainedHypothesis object at 0x795235eaa1d0>, <__main__.ConstrainedHypothesis object at 0x795235eaa7a0>]

clause(id=0, positive=[[3, 4, 5]], negative=[], satisfy=False)
clause(id=1, positive=[[3, 4]], negative=[], satisfy=False)
clause(id=2, positive=[[4, 5]], negative=[], satisfy=False)
Root: ([] -> 3 ([4] -> 4 ([5])) -> 4 ([5]))
State: [([] -> 3 ([4] -> 4 ([5])) -> 4 ([5])), ]
Met_phrases: None
{3, 4}
()
()

clause(id=0, positive=[[3, 4, 5]], negative=[], satisfy=False)
clause(id=1, positive=[[3, 4]], negative=[], satisfy=False)
clause(id=2, positive=[[4, 5]], negative=[], satisfy=False)
Root: ([] -> 3 ([4] -> 4 ([5])) -> 4 ([5]))
State: [([4] -> 4 ([5])), ]
Met_phrases: Non

In [5]:
# top k
# Large negative constant, by its name a finite stand-in for negative infinity.
# NOTE(review): not referenced in the visible part of this cell - confirm where it is applied
# and that -1000 is below every real score it is compared with.
NEGATIVE_INF = -1000

def topk_huggingface(timestep: int,
                     batch_size: int,
                     beam_size: int,
                     vocab_size: int,
                     pad_token_id: int,
                     prune_factor: int,
                     sat_tolerance: int,
                     beta: float,
                     inactive: np.array,
                     scores: np.array,
                     hypotheses: List[ConstrainedHypothesis],
                     num_fill: int,
                     early_stop: float = None) -> Tuple[np.array, np.array,
                                                        List[List[Union[ConstrainedHypothesis, None]]],
                                                        List[List[int]]]:
    """
    Selects, sentence by sentence, the ``num_fill`` beam candidates for the next step.

    The plain top-``beam_size`` candidates are computed for the whole batch first; each
    sentence is then handed to ``_sequential_topk`` together with its slice of hypotheses.
    Sentences whose hypotheses are all None get padding output.

    :param timestep: The current decoder timestep (forwarded to ``_sequential_topk``).
    :param batch_size: The number of segments in the batch.
    :param beam_size: The length of the beam for each segment.
    :param vocab_size: The size of vocabulary.
    :param pad_token_id: Word ID written for sentences that have no hypotheses.
    :param prune_factor: Forwarded to ``_sequential_topk``.
    :param sat_tolerance: Forwarded to ``_sequential_topk``.
    :param beta: Forwarded to ``_sequential_topk``.
    :param inactive: Indexed per sentence and forwarded to ``_sequential_topk``.
    :param scores: Torch tensor of scores (shape: (batch_size, beam_size * vocab_size)).
    :param hypotheses: The list of hypothesis objects. (length: (batch_size * beam_size,))
    :param num_fill: The number of required return beam
    :param early_stop: Forwarded to ``_sequential_topk``.
    :return: A tuple of the selected scores (batch_size, num_fill), the selected flat indices
        ``row * vocab_size + word_id`` (batch_size, num_fill; a float array), the selected
        hypotheses per sentence, and the number of met constraints per selected hypothesis.
    """

    top_scores, flat_token_idx = torch.topk(scores, beam_size, dim=1, largest=True, sorted=True)
    # Split the flat index into beam row and word ID.
    best_ids = (flat_token_idx // vocab_size).cpu().numpy()
    best_word_ids = (flat_token_idx % vocab_size).cpu().numpy()
    seq_scores = top_scores.cpu().detach().numpy()

    scores = torch.reshape(scores, [batch_size, beam_size, -1]).cpu().detach().numpy()

    fill_shape = (batch_size, num_fill)
    select_best_ids = np.full(fill_shape, -1.0)
    select_best_word_ids = np.full(fill_shape, -1.0)
    select_seq_scores = np.zeros(fill_shape)
    select_hypotheses = [[None] * num_fill for _ in range(batch_size)]
    select_num_mets = [[-1] * num_fill for _ in range(batch_size)]

    for sentno in range(batch_size):
        first_row = sentno * beam_size
        rows = slice(first_row, first_row + beam_size)
        is_missing = [hyp is None for hyp in hypotheses[rows]]

        if all(is_missing):
            # No hypotheses for this sentence: emit padding.
            select_best_ids[sentno] = [0] * num_fill
            select_best_word_ids[sentno] = [pad_token_id] * num_fill
            select_seq_scores[sentno] = [0] * num_fill
            select_hypotheses[sentno] = [None] * num_fill
            select_num_mets[sentno] = [-1] * num_fill
            continue

        assert not any(is_missing), 'Bad state'

        selected = _sequential_topk(timestep,
                                    beam_size,
                                    prune_factor,
                                    sat_tolerance,
                                    beta,
                                    inactive[sentno],
                                    scores[sentno],
                                    hypotheses[rows],
                                    best_ids[sentno],
                                    best_word_ids[sentno],
                                    seq_scores[sentno],
                                    num_fill=num_fill,
                                    early_stop=early_stop)
        (select_best_ids[sentno], select_best_word_ids[sentno], select_seq_scores[sentno],
         select_hypotheses[sentno], select_num_mets[sentno]) = selected

    select_raw_token_idx = select_best_ids * vocab_size + select_best_word_ids
    return select_seq_scores, select_raw_token_idx, select_hypotheses, select_num_mets


def _sequential_topk(timestep: int,
                     beam_size: int,
                     prune_factor: int,
                     sat_tolerance: int,
                     beta: float,
                     inactive: np.array,
                     scores: np.array,
                     hypotheses: List[ConstrainedHypothesis],
                     best_ids: np.array,
                     best_word_ids: np.array,
                     sequence_scores: np.array,
                     num_fill: int = None,
                     early_stop: float = None) -> Tuple[np.array, np.array, np.array,
                                                        List[ConstrainedHypothesis], List[int]]:
    """
    Builds a new topk list such that the beam contains hypotheses having completed different numbers of constraints.
    These items are built from three different types: (1) the best items across the whole
    scores matrix, (2) the set of words that must follow existing constraints, and (3) k-best items from each row.

    :param timestep: The current decoder timestep.
    :param beam_size: The length of the beam for each segment.
    :param prune_factor: Candidates from (2) and (3) are only kept when their dense rank within the
        whole scores matrix is strictly below this value.
    :param sat_tolerance: Unfinished candidates that satisfy fewer than
        (max satisfied constraints - sat_tolerance) constraints are dropped.
    :param beta: Weight of the log(max_process) term subtracted from the length-normalized score
        when ranking unfinished candidates.
    :param inactive: Array listing inactive rows (shape: (beam_size,)). Rows whose hypothesis has no
        positive constraint state are marked inactive in place.
    :param scores: The scores array (shape: (beam_size, target_vocab_size)).
    :param hypotheses: The list of hypothesis objects. (length: (beam_size,))
    :param best_ids: The current list of best hypotheses (shape: (beam_size,)).
    :param best_word_ids: The parallel list of best word IDs (shape: (beam_size,)).
    :param sequence_scores: (shape: (beam_size, 1)).
    :param num_fill: Number of entries to return. Must be at least beam_size; defaults to beam_size.
    :param early_stop: Fraction of beam_size; for an already finished hypothesis only the top
        int(beam_size * early_stop) tokens of its row are scanned for EOS. None scans the top beam_size.
    :return: A tuple containing the best hypothesis rows, the best hypothesis words, the scores,
        the updated constrained hypotheses, and the number of met constraints of each returned hypothesis.
    """

    candidates = set()
    finished_candidates = set()
    # the best item (constrained or not) in that row
    best_next = np.argmax(scores, axis=1)
    rank = rankdata(-1 * scores, method='dense').reshape(scores.shape)

    # (1) Add all of the top-k items (which were passed) in as long as they pass the constraints
    for row, col, seq_score in zip(best_ids, best_word_ids, sequence_scores):
        row, col = int(row), int(col)
        seq_score = float(seq_score)
        new_item = hypotheses[row].advance(col)
        cand = ConstrainedCandidate(row, col, seq_score, new_item)
        if cand.hypothesis.finished():
            finished_candidates.add(cand)
        elif hypotheses[row].is_valid(col) or int(best_next[row]) == col:
            candidates.add(cand)

    # (row, col) pairs already covered by (1); a set gives O(1) membership tests below.
    hit = {(int(r), int(c)) for r, c in zip(best_ids, best_word_ids)}

    # A missing early_stop used to raise TypeError in int(beam_size * None); fall back to beam_size.
    eos_scan_width = beam_size if early_stop is None else int(beam_size * early_stop)

    # For each hypothesis, we add (2) all the constraints that could follow it and
    # (3) the best item (constrained or not) in that row
    for row in range(beam_size):
        # Skip inactive rows and rows without a hypothesis. The previous guard
        # (`or hypotheses[row]`) skipped every real hypothesis, which made (2) and (3) unreachable.
        if inactive[row] or hypotheses[row] is None:
            continue

        hyp = hypotheses[row]

        # Nothing to extend when the hypothesis tracks no positive constraints.
        if hyp.positive_state is None:
            inactive[row] = True
            continue

        # (2) add all the constraints that could extend this
        # (copied so the container returned by allowed() is never mutated)
        nextones = set(hyp.positive_state.allowed())
        logger.debug("positive state: %s", nextones)

        # Token ids of this row, best score first; reused for (3) and for the EOS scan.
        row_order = np.argsort(scores[row])[::-1]

        # (3) add the best items (if it's valid)
        for col in row_order[:beam_size]:
            if hyp.is_valid(col):
                nextones.add(col)

        # Now, create new candidates for each of these items
        for col in nextones:
            if (row, int(col)) not in hit and (rank[row, col] < prune_factor and scores[row, col] > NEGATIVE_INF):
                new_item = hyp.advance(col)
                score = scores[row, col]
                cand = ConstrainedCandidate(row, col, score, new_item)
                if cand.hypothesis.finished() and col in cand.hypothesis.eos():
                    finished_candidates.add(cand)
                else:
                    candidates.add(cand)

        # Add finished candidates in finished set:
        if hyp.finished():
            for col in row_order[:eos_scan_width]:
                if col in hyp.eos() and scores[row, col] > NEGATIVE_INF:
                    new_item = hyp.advance(col)
                    score = scores[row, col]
                    cand = ConstrainedCandidate(row, col, score, new_item)
                    finished_candidates.add(cand)

    if num_fill is not None:
        # num_fill == beam_size behaves exactly like the default, so it is accepted as well.
        assert num_fill >= beam_size, "at least select number of beam candidates"
    else:
        num_fill = beam_size

    # all the sentences finish without satisfy all constraints
    if (not candidates) and (not finished_candidates):
        logger.debug('edge case: no candidate survived, falling back to the plain top-k items')
        for row, col, seq_score in zip(best_ids, best_word_ids, sequence_scores):
            row, col = int(row), int(col)
            seq_score = float(seq_score)
            new_item = hypotheses[row].advance(col)
            cand = ConstrainedCandidate(row, col, seq_score, new_item)
            candidates.add(cand)

    chunk_candidates = []
    if candidates:
        # Sort the candidates.
        sorted_candidates = sorted(candidates, key=attrgetter('score'), reverse=True)
        max_satisfy = max([x.hypothesis.num_met() for x in sorted_candidates])
        sorted_candidates = [x for x in sorted_candidates if x.hypothesis.num_met() >= max_satisfy - sat_tolerance]

        # Rank by length-normalized score, adjusted by progress on the constraint in process.
        for cand in sorted_candidates:
            cand.rank = cand.score / (timestep + 1)
            if cand.hypothesis.max_process:
                cand.rank -= beta * math.log(cand.hypothesis.max_process)
        sorted_candidates = sorted(sorted_candidates, key=attrgetter('rank'), reverse=True)

        # Bucket candidates in each group by met order
        all_orders = set([x.hypothesis.met_order() for x in sorted_candidates])
        grouped_order_candidates = [[x for x in sorted_candidates if x.hypothesis.met_order() == o] for o in all_orders]

        # Group the top_i candidate of each group in chunk
        chunk_candidates = []
        num_chunk = max([len(x) for x in grouped_order_candidates])
        for i in range(num_chunk):
            chunk_i = []
            for g in grouped_order_candidates:
                if len(g) > i:
                    chunk_i.append(g[i])
            chunk_candidates.append(chunk_i)
        # Sort candidates in each chunk by score
        chunk_candidates = [sorted(x, key=attrgetter('rank'), reverse=True) for x in chunk_candidates]

    # Finished candidates come first; unfinished ones fill the remaining slots chunk by chunk.
    pruned_candidates = sorted(finished_candidates, key=attrgetter('score'), reverse=True)[:(num_fill if not candidates else beam_size)]
    num_finish = len(pruned_candidates)
    for chunk in chunk_candidates:
        if len(pruned_candidates) >= num_fill:
            break

        chunk = [x for x in chunk if x not in pruned_candidates]
        if not chunk:
            continue

        pruned_candidates.extend(chunk[:num_fill - len(pruned_candidates)])

    if num_fill > beam_size:
        if candidates:
            # Keep the three segments (finished / in-beam / extra) separate, each sorted by score.
            select_num = num_finish + beam_size
            complete_candidates = sorted(pruned_candidates[:num_finish], key=attrgetter('score'), reverse=True)
            include_candidates = sorted(pruned_candidates[num_finish:select_num], key=attrgetter('score'), reverse=True)
            extra_candidates = sorted(pruned_candidates[select_num:], key=attrgetter('score'), reverse=True)
            pruned_candidates = complete_candidates + include_candidates + extra_candidates
    else:
        pruned_candidates = sorted(pruned_candidates, key=attrgetter('score'), reverse=True)

    num_pruned_candidates = len(pruned_candidates)

    # Pad the beam (by repeating the last candidate) so array assignment still works
    if num_pruned_candidates < num_fill:
        pruned_candidates += [pruned_candidates[num_pruned_candidates - 1]] * (num_fill - num_pruned_candidates)

    assert len(pruned_candidates) == num_fill, 'candidates number mismatch'

    return (np.array([x.row for x in pruned_candidates]),
            np.array([x.col for x in pruned_candidates]),
            np.array([x.score for x in pruned_candidates]),
            [x.hypothesis for x in pruned_candidates],
            [x.hypothesis.num_met() for x in pruned_candidates])


In [6]:
# generate: beam-search text generation with lexical constraints.
# Logger shared by the functions defined in this cell.
logger = logging.getLogger(__name__)


@torch.no_grad()
def postprocess_next_token_scores(
        self,
        scores,
        input_ids,
        no_repeat_ngram_size,
        bad_words_ids,
        cur_len,
        min_length,
        max_length,
        eos_token_id,
        repetition_penalty,
        batch_size,
        num_beams,
):
    """Apply the decoding restrictions to the next-token ``scores`` and return them.

    ``scores`` is modified in place (and also returned). In order, the function:
    applies the repetition penalty when ``repetition_penalty`` differs from 1.0,
    blocks ``eos_token_id`` while ``cur_len`` is below ``min_length``, blocks tokens
    that would repeat an n-gram of size ``no_repeat_ngram_size``, and blocks tokens
    that would complete an entry of ``bad_words_ids``. ``max_length`` is accepted
    but not used here.
    """
    blocked = float("-inf")

    # Repetition penalty as described in the CTRL paper (https://arxiv.org/abs/1909.05858).
    penalty_requested = repetition_penalty != 1.0
    if penalty_requested:
        self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)

    # EOS may not be produced before the minimum length is reached.
    eos_too_early = eos_token_id is not None and cur_len < min_length
    if eos_too_early:
        scores[:, eos_token_id] = blocked

    if no_repeat_ngram_size > 0:
        # One list of forbidden token ids per hypothesis; logic taken from fairseq:
        # https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
        total_hypotheses = batch_size * num_beams
        ngram_bans = calc_banned_ngram_tokens(input_ids, total_hypotheses, no_repeat_ngram_size, cur_len)
        for hyp_idx, forbidden in enumerate(ngram_bans):
            scores[hyp_idx, forbidden] = blocked

    if bad_words_ids is not None:
        # One list of forbidden token ids per hypothesis, derived from the bad-word list.
        bad_word_bans = calc_banned_bad_words_ids(input_ids, bad_words_ids)
        for hyp_idx, forbidden in enumerate(bad_word_bans):
            scores[hyp_idx, forbidden] = blocked

    return scores
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[List[List[int]]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    constraints: Optional[List[Optional[ConstrainedHypothesis]]] = None,
    prune_factor: Optional[int] = None,
    sat_tolerance: Optional[int] = None,
    beta: Optional[float] = None,
    early_stop: Optional[float] = None,
    **model_specific_kwargs
) -> torch.LongTensor:
    r""" Generates sequences for models with a LM head using beam search with lexical constraints.

    Only the beam-search path (`num_beams` > 1) is implemented in this variant; any other
    value of `num_beams` raises `NotImplementedError`.

    Adapted in part from `Facebook's XLM beam search code`_.

    .. _`Facebook's XLM beam search code`:
       https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


    Parameters:

        input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
            The sequence used as a prompt for the generation. If `None` the method initializes
            it as a `torch.LongTensor` of shape `(1, 1)` filled with `bos_token_id`.

        max_length: (`optional`) int
            The max length of the sequence to be generated.  Between `min_length` and infinity.
            Defaults to `config.max_length`.

        min_length: (`optional`) int
            The min length of the sequence to be generated.  Between 0 and infinity.
            Defaults to `config.min_length`.

        do_sample: (`optional`) bool
            Forwarded to the beam search. Defaults to `config.do_sample`.

        early_stopping: (`optional`) bool
            if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch.
            Defaults to `config.early_stopping`.

        num_beams: (`optional`) int
            Number of beams for beam search. Must be greater than 1 here. Defaults to `config.num_beams`.

        temperature: (`optional`) float
            The value used to modulate the next token probabilities. Must be strictly positive.
            Defaults to `config.temperature`.

        top_k: (`optional`) int
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
            Defaults to `config.top_k`.

        top_p: (`optional`) float
            The cumulative probability of parameter highest probability vocabulary tokens to keep for
            nucleus sampling. Must be between 0 and 1. Defaults to `config.top_p`.

        repetition_penalty: (`optional`) float
            The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty.
            Defaults to `config.repetition_penalty`.

        pad_token_id: (`optional`) int
            Padding token. Defaults to the model's `pad_token_id`; if that is also `None`,
            `eos_token_id` is used (with a warning).

        bos_token_id: (`optional`) int
            BOS token. Defaults to `bos_token_id` as defined in the models config.

        eos_token_id: (`optional`) int
            EOS token. Defaults to `eos_token_id` as defined in the models config.

        length_penalty: (`optional`) float
            Exponential penalty to the length. Defaults to `config.length_penalty`.

        no_repeat_ngram_size: (`optional`) int
            If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.

        bad_words_ids: (`optional`) list of lists of int
            `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.

        num_return_sequences: (`optional`) int
            The number of independently computed returned sequences for each element in the batch.
            Defaults to `config.num_return_sequences`.

        attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            Defaults to `None`, in which case it is derived from `pad_token_id` (or set to all ones).

            `What are attention masks? <../glossary.html#attention-mask>`__

        decoder_start_token_id=None: (`optional`) int
            If an encoder-decoder model starts decoding with a different token than BOS.
            Defaults to `None` and is changed to `BOS` later.

        use_cache: (`optional`) bool
            If `use_cache` is True, past key values are used to speed up decoding if applicable to model.
            Defaults to `False` in this variant (the config value is deliberately ignored).

        constraints: list of `ConstrainedHypothesis`
            Constraint state objects forwarded to the beam search. Must be a non-empty list: the
            beam search reads `constraints[0].eos()` for its end condition.
            (Presumably one entry per input sequence; confirm against `_generate_beam_search`.)

        prune_factor, sat_tolerance, beta, early_stop:
            Hyper-parameters of the constrained search. They have no config fallback and are
            forwarded unchanged to `_generate_beam_search`.

        model_specific_kwargs: (`optional`) dict
            Additional model specific kwargs will be forwarded to the `forward` function of the model.

    Return:

        output: whatever `_generate_beam_search` returns for the prepared inputs.

    Raises:

        AttributeError: if the model has no LM head.
        AssertionError: if an argument fails validation (including missing `constraints`).
        NotImplementedError: if `num_beams` is not greater than 1.

    """

    # We cannot generate if the model does not have a LM head
    if self.get_output_embeddings() is None:
        raise AttributeError(
            "You tried to generate sequences with a model that does not have a LM Head."
            "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
        )

    # Resolve every unset argument from the model config.
    max_length = max_length if max_length is not None else self.config.max_length
    min_length = min_length if min_length is not None else self.config.min_length
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
    use_cache = use_cache if use_cache is not None else False # self.config.use_cache
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    temperature = temperature if temperature is not None else self.config.temperature
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    no_repeat_ngram_size = (
        no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
    )
    bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
    )
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
    )

    if input_ids is not None:
        batch_size = input_ids.shape[0]  # overriden by the input batch_size
    else:
        batch_size = 1

    assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
    assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
    assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
    assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
    assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
    assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
    assert temperature > 0, "`temperature` should be strictly positive."
    assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
    assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
    assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
    assert input_ids is not None or (
        isinstance(bos_token_id, int) and bos_token_id >= 0
    ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
    assert pad_token_id is None or (
        isinstance(pad_token_id, int) and (pad_token_id >= 0)
    ), "`pad_token_id` should be a positive integer."
    assert (eos_token_id is None) or (
        isinstance(eos_token_id, int) and (eos_token_id >= 0)
    ), "`eos_token_id` should be a positive integer."
    assert length_penalty > 0, "`length_penalty` should be strictly positive."
    assert (
        isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
    ), "`no_repeat_ngram_size` should be a positive integer."
    assert (
        isinstance(num_return_sequences, int) and num_return_sequences > 0
    ), "`num_return_sequences` should be a strictly positive integer."
    assert (
        bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
    ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

    if input_ids is None:
        assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
            "you should either supply a context to complete as `input_ids` input "
            "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
        )
        input_ids = torch.full(
            (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
        )
    else:
        assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

    # not allow to duplicate outputs when greedy decoding
    if do_sample is False:
        if num_beams == 1:
            # no_beam_search greedy generation conditions
            assert (
                num_return_sequences == 1
            ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

        else:
            # beam_search greedy generation conditions
            assert (
                num_beams >= num_return_sequences
            ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

    # Only beam search is implemented. Fail here, before the (possibly expensive)
    # encoder forward pass, instead of after it.
    if num_beams <= 1:
        raise NotImplementedError(
            "Only beam-search decoding is implemented in this `generate`; please set `num_beams` > 1."
        )

    # The beam search dereferences `constraints[0]`; report a missing value clearly.
    assert (
        constraints is not None and len(constraints) > 0
    ), "`constraints` should be a non-empty list of `ConstrainedHypothesis` objects."

    # create attention mask if necessary
    if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
        attention_mask = input_ids.ne(pad_token_id).long()
    elif attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    # set pad_token_id to eos_token_id if not set. Important that this is done after
    # attention_mask is created
    if pad_token_id is None and eos_token_id is not None:
        logger.warning(
            "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
        )
        pad_token_id = eos_token_id

    # current position and vocab size
    if hasattr(self.config, "vocab_size"):
        vocab_size = self.config.vocab_size
    elif (
        self.config.is_encoder_decoder
        and hasattr(self.config, "decoder")
        and hasattr(self.config.decoder, "vocab_size")
    ):
        vocab_size = self.config.decoder.vocab_size

    # set effective batch size and effective batch multiplier according to do_sample
    if do_sample:
        effective_batch_size = batch_size * num_return_sequences
        effective_batch_mult = num_return_sequences
    else:
        effective_batch_size = batch_size
        effective_batch_mult = 1

    if self.config.is_encoder_decoder:
        if decoder_start_token_id is None:
            decoder_start_token_id = bos_token_id

        assert (
            decoder_start_token_id is not None
        ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
        assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
        assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

        # get encoder and store encoder outputs
        encoder = self.get_encoder()

        encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)

    # Expand input ids if num_beams > 1 or num_return_sequences > 1
    if num_return_sequences > 1 or num_beams > 1:
        input_ids_len = input_ids.shape[-1]
        input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
        attention_mask = attention_mask.unsqueeze(1).expand(
            batch_size, effective_batch_mult * num_beams, input_ids_len
        )

        input_ids = input_ids.contiguous().view(
            effective_batch_size * num_beams, input_ids_len
        )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
        attention_mask = attention_mask.contiguous().view(
            effective_batch_size * num_beams, input_ids_len
        )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

    if self.config.is_encoder_decoder:
        # create empty decoder_input_ids
        input_ids = torch.full(
            (effective_batch_size * num_beams, 1),
            decoder_start_token_id,
            dtype=torch.long,
            device=next(self.parameters()).device,
        )
        cur_len = 1

        assert (
            batch_size == encoder_outputs[0].shape[0]
        ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

        # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
        expanded_batch_idxs = (
            torch.arange(batch_size)
            .view(-1, 1)
            .repeat(1, num_beams * effective_batch_mult)
            .view(-1)
            .to(input_ids.device)
        )
        # expand encoder_outputs
        encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])

    else:
        encoder_outputs = None
        cur_len = input_ids.shape[-1]

    # Release cached GPU memory before the memory-hungry beam search starts.
    free_gpu_cache()

    # num_beams > 1 is guaranteed by the early check above.
    output = _generate_beam_search(
        self,
        input_ids=input_ids,
        cur_len=cur_len,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        bad_words_ids=bad_words_ids,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        decoder_start_token_id=decoder_start_token_id,
        eos_token_id=eos_token_id,
        batch_size=effective_batch_size,
        num_return_sequences=num_return_sequences,
        length_penalty=length_penalty,
        num_beams=num_beams,
        vocab_size=vocab_size,
        encoder_outputs=encoder_outputs,
        attention_mask=attention_mask,
        use_cache=use_cache,
        constraints=constraints,
        prune_factor=prune_factor,
        sat_tolerance=sat_tolerance,
        beta=beta,
        early_stop=early_stop,
        model_specific_kwargs=model_specific_kwargs,
    )
    return output


class BeamHypotheses(object):
    """Bounded n-best list of finished hypotheses kept for one sentence.

    Entries are stored in ``self.beams`` as ``(score, hyp, num_met)`` tuples, where
    ``score`` is the summed log-probability divided by ``len(hyp) ** length_penalty``.
    The list holds at most ``2 * num_beams`` entries.
    """

    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Create an empty n-best list.
        """
        self.beams = []
        self.worst_score = 1e9
        # Capacity is twice the requested beam count.
        self.num_beams = 2 * num_beams
        # The bos_token does not count towards the length.
        self.max_length = max_length - 1
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping

    def __len__(self):
        """
        How many hypotheses are currently stored.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs, num_met):
        """
        Insert a hypothesis if there is room or it beats the current worst score.
        """
        normalized = sum_logprobs / len(hyp) ** self.length_penalty
        has_room = len(self) < self.num_beams
        if not (has_room or normalized > self.worst_score):
            return

        self.beams.append((normalized, hyp, num_met))
        if len(self) <= self.num_beams:
            self.worst_score = min(normalized, self.worst_score)
            return

        # Over capacity: evict the lowest-scoring entry; the runner-up becomes the new worst.
        ranked = sorted((entry[0], position) for position, entry in enumerate(self.beams))
        evict_position = ranked[0][1]
        self.worst_score = ranked[1][0]
        self.beams.pop(evict_position)

    def is_done(self, best_sum_logprobs, cur_len=None):
        """
        Tell whether this sentence is finished: the list is full and either early
        stopping is enabled or no running hypothesis can still beat the worst stored one.
        """
        if len(self) < self.num_beams:
            return False
        if self.early_stopping:
            return True

        length = self.max_length if cur_len is None else cur_len
        best_possible = best_sum_logprobs / length ** self.length_penalty
        return self.worst_score >= best_possible


def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        decoder_start_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
        encoder_outputs,
        attention_mask,
        use_cache,
        constraints,
        prune_factor,
        sat_tolerance,
        beta,
        early_stop,
        model_specific_kwargs,
):
    """Generate sequences for each example with constraint-aware beam search.

    At every step the model is run on the current beams, the next-token
    log-probabilities are post-processed (repetition penalty, temperature,
    `postprocess_next_token_scores`, per-beam banned tokens from the constraint
    objects) and `topk_huggingface` picks `2 * num_beams` candidates per sentence
    together with the constraint state and `num_met` value of each candidate.
    Candidates ending in an end token are stored in a per-sentence
    `BeamHypotheses`; the others become the beams of the next step.

    Args (only what this function visibly does with them):
        self: the model; called as `self(**model_inputs)` and used for
            `prepare_inputs_for_generation`, `enforce_repetition_penalty_`,
            `adjust_logits_during_generation`, `_reorder_cache`, `config`.
        input_ids: token ids, one row per beam; rows are re-ordered and extended
            by one column per step.
        cur_len: current sequence length; its initial value marks the prompt length.
        max_length: the loop runs while `cur_len < max_length`.
        min_length: forwarded to `postprocess_next_token_scores`; additionally the
            constraint end tokens are banned while fewer than `min_length` tokens
            have been generated after the prompt.
        do_sample: sampling is not implemented; a truthy value raises.
        early_stopping: forwarded to `BeamHypotheses`.
        temperature, repetition_penalty, no_repeat_ngram_size, bad_words_ids:
            score post-processing options.
        pad_token_id, eos_token_id: padding / end-of-sequence ids.
        batch_size, num_beams, vocab_size: sizes used to reshape the score tensors.
        num_return_sequences, length_penalty: used when selecting final hypotheses.
        encoder_outputs: when not None, seeds `past` as `(encoder_outputs, None)`.
        attention_mask: used to find the last non-masked prompt position per row
            and extended by one column per step for decoder-only models.
        use_cache, model_specific_kwargs: forwarded to `prepare_inputs_for_generation`.
        constraints: per-beam constraint objects exposing `eos()`, `avoid()` and
            `num_met()`; replaced each step by the output of `topk_huggingface`.
        prune_factor, sat_tolerance, beta, early_stop: forwarded to `topk_huggingface`.
        top_k, top_p, bos_token_id, decoder_start_token_id: not read in this body.

    Returns:
        (decoded, best_scores, best_sum_logprobs): the selected sequences as one
        tensor (padded with `pad_token_id` when lengths differ), their
        length-normalised scores, and `score * len(hyp) ** length_penalty`.

    Raises:
        NotImplementedError: if `do_sample` is truthy.
        ValueError: if a sentence ends a step with no open beam while not done.
    """
    # end condition
    cons_eos = constraints[0].eos()

    # Index of the last non-masked prompt token per row, expanded over the vocabulary
    # so it can be used with `gather` on the first step (prompts may be right-padded).
    last_non_masked_idx = (torch.sum(attention_mask, dim=1) - 1).int()
    start_idx = (last_non_masked_idx).view(-1, 1).repeat(1, self.config.vocab_size).unsqueeze(1).long()

    init_length = cur_len
    # Positions 0..init_length-1 for every row; from the last non-masked index onwards
    # every position is overwritten with the value found at that index.
    position_ids = torch.tensor([list(range(init_length)) for i in range(input_ids.shape[0])])
    for i, position_ids_slice in enumerate(position_ids):
        position_ids_slice[last_non_masked_idx[i]:] = position_ids_slice[last_non_masked_idx[i]]
    position_ids = position_ids.to(input_ids.device)

    # generated hypotheses
    generated_hyps = [
        BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
        for _ in range(batch_size)
    ]

    # scores for each sentence in the beam
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)

    # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
    if do_sample is False:
        beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

    # cache compute states
    past = (encoder_outputs, None) if encoder_outputs is not None else None

    # done sentences
    done = [False for _ in range(batch_size)]

    # init number of met clauses
    num_mets = [x.num_met() for x in constraints]

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(
            input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
        )
        # Override whatever prepare_inputs_for_generation produced for these two keys.
        model_inputs["attention_mask"] = attention_mask
        model_inputs["position_ids"] = position_ids[:, -1].unsqueeze(-1) if past else position_ids

        outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
        if cur_len == init_length:
            # First step: read the logits at the last non-masked prompt position of each row.
            next_token_logits = outputs[0].gather(1, start_idx).squeeze(1)
        else:
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

        # NOTE(review): the cache update below is disabled, so `past` keeps its initial
        # value for the whole loop — confirm this is intentional.
        # if model has past, then set the past variable to speed up decoding
        # if self._use_cache(outputs, use_cache):
        #     past = outputs[1]

        # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            self.enforce_repetition_penalty_(
                next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
            )

        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature

        if self.config.is_encoder_decoder and do_sample is False:
            next_token_logits = self.adjust_logits_during_generation(
                next_token_logits, cur_len=cur_len, max_length=max_length
            )

        scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
#         free_gpu_cache()
        scores = postprocess_next_token_scores(
            self=self,
            scores=scores,
            input_ids=input_ids,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            cur_len=cur_len,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            repetition_penalty=repetition_penalty,
            batch_size=batch_size,
            num_beams=num_beams,
        )

        assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
            scores.shape, (batch_size * num_beams, vocab_size)
        )

        # Collect (row, token) pairs that must not be produced: tokens each constraint
        # asks to avoid, plus the constraint end tokens while the generated part is
        # still shorter than min_length.
        avoid_idx = []
        for i, c in enumerate(constraints):
            if c is not None:
                avoid_idx.extend([[i, x] for x in c.avoid()])
                if cur_len - init_length < min_length:
                    avoid_idx.extend([[i, x] for x in c.eos()])
        if avoid_idx:
            # Turn the index pairs into a dense boolean mask shaped like `scores`.
            banned_mask = torch.LongTensor(avoid_idx)
            indices = torch.ones(len(banned_mask))
            banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(
                scores.device).to_dense().bool()
            scores.masked_fill_(banned_mask, -float("inf"))

        if do_sample:
            raise NotImplementedError
        else:
            next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

            # re-organize to group the beam together (we are keeping top hypothesis across beams)
            full_scores = next_scores.view(
                batch_size, num_beams * vocab_size
            )  # (batch_size, num_beams * vocab_size)

            # The plain top-k result is only used below for its dtype and device; the
            # actual candidates come from topk_huggingface.
            next_scores, next_tokens = torch.topk(full_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

            pick_scores, pick_tokens, constraints, num_mets = topk_huggingface(timestep=cur_len,
                                                                               batch_size=batch_size,
                                                                               beam_size=num_beams,
                                                                               vocab_size=vocab_size,
                                                                               pad_token_id=pad_token_id,
                                                                               prune_factor=prune_factor,
                                                                               sat_tolerance=sat_tolerance,
                                                                               beta=beta,
                                                                               inactive=np.zeros((batch_size, num_beams)),
                                                                               scores=full_scores,
                                                                               hypotheses=constraints,
                                                                               num_fill=2 * num_beams,
                                                                               early_stop=early_stop)

            next_scores = torch.tensor(pick_scores, dtype=next_scores.dtype, device=next_scores.device)
            next_tokens = torch.tensor(pick_tokens, dtype=next_tokens.dtype, device=next_tokens.device)

        assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)

        # next batch beam content: tuples of (score, token_id, source_beam_index, constraint, num_met)
        next_batch_beam = []

        # for each sentence
        for batch_idx in range(batch_size):

            # if we are done with this sentence, add a pad token
            if done[batch_idx]:
                assert (
                    len(generated_hyps[batch_idx]) >= num_beams
                ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                next_batch_beam.extend([(0, pad_token_id, 0, None, -1)] * num_beams)  # pad the batch
                continue

            # next sentence beam content
            next_sent_beam = []

            # next tokens for this sentence
            # (presumably topk_huggingface returns constraints/num_mets indexed as
            # [batch_idx][candidate] — verify against its definition)
            for beam_token_rank, (beam_token_id, beam_token_score, constraint, num_met) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], constraints[batch_idx], num_mets[batch_idx])
            ):
                # get beam and token IDs
                beam_id = beam_token_id // vocab_size
                token_id = beam_token_id % vocab_size

                effective_beam_id = batch_idx * num_beams + beam_id
                sentence_end = token_id.item() in constraint.eos()
                # add to generated hypotheses if end of sentence or last iteration
                if ((eos_token_id is not None) and (token_id.item() == eos_token_id)) or sentence_end:
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    # NOTE(review): the flag below is computed but unused because the
                    # `continue` is commented out, so lower-ranked finished candidates are added too.
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                    # if is_beam_token_worse_than_top_num_beams:
                    #     continue
                    generated_hyps[batch_idx].add(
                        torch.cat((input_ids[effective_beam_id], token_id.view([1]))), beam_token_score.item(), num_met,
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_sent_beam.append((beam_token_score, token_id, effective_beam_id, constraint, num_met))

                # once the beam for next step is full, don't add more tokens to it.
                if len(next_sent_beam) == num_beams:
                    break

            # Check if we are done so that we can save a pad step if all(done)
            # (a sentence is also marked done when no open beam is left for it)
            done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                next_scores[batch_idx][:beam_token_rank + 1].max().item(), cur_len=cur_len
            ) or not next_sent_beam

            # Fewer than num_beams open candidates: fill up by repeating the last open
            # candidate, or with pad entries when the sentence is done.
            if len(next_sent_beam) < num_beams:
                if next_sent_beam:
                    pad_candidate = next_sent_beam[-1]
                elif done[batch_idx]:
                    pad_candidate = (0, pad_token_id, 0, None, -1)
                else:
                    raise ValueError('impossible search state')
                next_sent_beam += [pad_candidate] * (num_beams - len(next_sent_beam))

            # update next beam content
            assert len(next_sent_beam) == num_beams, "Beam should always be full"
            next_batch_beam.extend(next_sent_beam)
            assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"

        # stop when we are done with each sentence
        if all(done):
            break

        # sanity check / prepare next batch
        assert len(next_batch_beam) == batch_size * num_beams
        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
        beam_idx = input_ids.new([x[2] for x in next_batch_beam])
        constraints = [x[3] for x in next_batch_beam]
        num_mets = [x[4] for x in next_batch_beam]

        # re-order batch and update current length
        input_ids = input_ids[beam_idx, :]
        input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
        position_ids = position_ids[beam_idx, :]
        position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=1)
        cur_len = cur_len + 1

        # re-order internal states
        if past is not None:
            past = self._reorder_cache(past, beam_idx)

        # extend attention_mask for new generated input if only decoder
        if self.config.is_encoder_decoder is False:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx in range(batch_size):
        if done[batch_idx]:
            continue

        # test that beam scores match previously calculated scores if not eos and batch_idx not done
        # (uses the end tokens of the first constraint object captured at function entry)
        if eos_token_id is not None and all(
            (token_id % vocab_size).item() not in cons_eos for token_id in next_tokens[batch_idx]
        ):
            assert torch.all(
                next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
            ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
            )

        # need to add best num_beams hypotheses to generated hyps
        for beam_id in range(num_beams):
            effective_beam_id = batch_idx * num_beams + beam_id
            final_score = beam_scores[effective_beam_id].item()
            final_tokens = input_ids[effective_beam_id]
            final_num_met = num_mets[effective_beam_id]
            generated_hyps[batch_idx].add(final_tokens, final_score, final_num_met)

    # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
    output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
    output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

    # select the best hypotheses
    sent_lengths = input_ids.new(output_batch_size)
    best, best_scores, best_sum_logprobs = [], [], []

    # retrieve best hypotheses
    for i, hypotheses in enumerate(generated_hyps):
        # Rank by number of met clauses first, then by length-normalised score.
        sorted_hyps = sorted(hypotheses.beams, key=lambda x: (x[2], x[0]), reverse=True)
        for j in range(output_num_return_sequences_per_batch):
            effective_batch_idx = output_num_return_sequences_per_batch * i + j
            # NOTE(review): index 0 is read for every j, so num_return_sequences > 1
            # yields the same hypothesis repeatedly — confirm whether that is intended.
            best_score, best_hyp, _ = sorted_hyps[0]
            sent_lengths[effective_batch_idx] = len(best_hyp)
            best.append(best_hyp)
            best_scores.append(best_score)
            # Undo the length normalisation applied by BeamHypotheses.add.
            best_sum_logprobs.append(best_score * (len(best_hyp) ** length_penalty))

    # shorter batches are padded
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`Pad_token_id` has to be defined"
        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
        decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

        # fill with hypothesis and eos_token_id if necessary
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
    else:
        # none of the hypotheses have an eos_token
        # NOTE(review): this asserts a generator object, which is always truthy, so the
        # length check never actually runs.
        assert (len(hypo) == max_length for hypo in best)
        decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

    return decoded, best_scores, best_sum_logprobs


def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
    return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)


def _use_cache(self, outputs, use_cache):
    """During generation, decide whether to pass the `past` variable to the next forward pass."""
    if len(outputs) <= 1 or use_cache is False:
        return False
    if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
        return False
    return True


def calc_banned_ngram_tokens(
    prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int
) -> List[List[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search.

    For every hypothesis, list the tokens that would complete an n-gram (of size
    ``no_repeat_ngram_size``) that already occurs in that hypothesis.

    Args:
        prev_input_ids: 2-D tensor of token ids, indexed as ``[hypothesis, position]``.
        num_hypos: number of leading rows of ``prev_input_ids`` to inspect.
        no_repeat_ngram_size: n-gram size that may occur only once.
        cur_len: number of tokens currently in each hypothesis.

    Returns:
        A list with ``num_hypos`` entries; entry ``i`` holds the token ids banned as
        the next token of hypothesis ``i`` (possibly empty, possibly with repeats).
    """
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]

    # Where the last (n-1) tokens of a hypothesis start; they form the prefix of the
    # n-gram that the next token would complete.
    prefix_start = cur_len + 1 - no_repeat_ngram_size

    banned_tokens = []
    for hypo_idx in range(num_hypos):
        gen_tokens = prev_input_ids[hypo_idx].tolist()

        # Map each (n-1)-token prefix seen so far to the tokens that followed it.
        followers = {}
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            followers.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])

        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        current_prefix = tuple(prev_input_ids[hypo_idx, prefix_start:cur_len].tolist())
        banned_tokens.append(followers.get(current_prefix, []))
    return banned_tokens


def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
    """List, for every hypothesis, the tokens that would complete a banned word.

    A banned sequence of length 1 is always banned. A longer sequence bans its
    last token only for hypotheses whose current tail equals the sequence
    without that last token.

    Args:
        prev_input_ids: iterable of per-hypothesis token rows; each row must
            provide ``tolist()``.
        bad_words_ids: iterable of non-empty token-id lists that must not be generated.

    Returns:
        A list with one entry per hypothesis containing the banned next-token ids.

    Raises:
        AssertionError: if any banned sequence is empty.
    """

    def _tokens_match(prev_tokens, tokens):
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        # Compare against the length of this hypothesis, not the number of rows in
        # the batch: otherwise multi-token bad words are silently skipped whenever
        # the batch has fewer rows than the prefix has tokens.
        if len(tokens) > len(prev_tokens):
            # if bad word tokens are longer than the hypothesis they can't be equal
            return False
        return prev_tokens[-len(tokens):] == tokens

    banned_tokens = []
    for prev_input_ids_slice in prev_input_ids:
        # Convert once per hypothesis instead of once per banned sequence.
        prev_tokens = prev_input_ids_slice.tolist()
        banned_tokens_slice = []

        for banned_token_seq in bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )

            if _tokens_match(prev_tokens, banned_token_seq[:-1]):
                banned_tokens_slice.append(banned_token_seq[-1])

        banned_tokens.append(banned_tokens_slice)

    return banned_tokens

In [7]:
# beam_search (baseline)
logger = logging.getLogger(__name__)


def main_no_neuro():
    """Decode every prompt of the input file with plain beam search (baseline).

    Prompts are tokenized, sorted by token length and decoded in chunks that
    only contain prompts of one length, so no padding is needed. One output
    line is written per prompt, in the original input order, each followed by
    an empty line.

    All settings come from the command line (see the argparse definitions);
    the model is moved to the ``'cuda'`` device.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_name", type=str, help="pretrained language model to use", default="gpt2-medium")
    parser.add_argument("--input_path", type=str, help="path of input file", default="/kaggle/input/commonsenseqa-no-neuro-fs/nl_questions_fs.txt")
    parser.add_argument("--output_file", type=str, help="output file", default="/kaggle/working/baseline_few_shot_gpt2-medium.txt")

    parser.add_argument('--batch_size', type=int, default=8,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=8,
                        help="Beam size for searching")
    parser.add_argument('--max_tgt_length', type=int, default=500,
                        help="maximum length of decoded sentences")
    parser.add_argument('--min_tgt_length', type=int, default=1,
                        help="minimum length of decoded sentences")
    parser.add_argument('--ngram_size', type=int, default=3,
                        help='all ngrams can only occur once')
    parser.add_argument('--length_penalty', type=float, default=1,
                        help="length penalty for beam search")
    # Not an option of this script: it only absorbs the `-f <connection file>`
    # argument present in sys.argv when running inside a Jupyter kernel.
    parser.add_argument('-f', type=str, default=None,
                        help="ignored; absorbs the kernel connection file passed by Jupyter")

    args = parser.parse_args()

    print(f"Decoding with: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelWithLMHead.from_pretrained(args.model_name)

    free_gpu_cache()
    model.eval()
    model = model.to('cuda')

    with open(args.input_path) as fin:
        input_lines = [line.strip() for line in fin.read().splitlines()]
    # For a quick smoke test, truncate here, e.g. input_lines = input_lines[:10]
    input_lines = [tokenizer.tokenize(x) for x in input_lines]
    # Keep the original position with each prompt so results can be written back in
    # input order after sorting by token length.
    input_lines = sorted(list(enumerate(input_lines)),
                         key=lambda x: len(x[1]))
    output_lines = [""] * len(input_lines)
    next_i = 0

    # Progress is counted in prompts: chunks can be smaller than batch_size, so a
    # per-batch total would be overrun.
    with tqdm(total=len(input_lines)) as pbar:
        while next_i < len(input_lines):
            full_chunk = input_lines[next_i:next_i + args.batch_size]
            prompt_tokens_num = sorted(set([len(x[1]) for x in full_chunk]))
            step_len = args.batch_size
            if len(prompt_tokens_num) > 1:
                # Mixed lengths: only take the prompts sharing the shortest length so
                # the batch can be stacked without padding.
                step_len = len([x for x in full_chunk if len(x[1]) == prompt_tokens_num[0]])

            _chunk = input_lines[next_i:next_i + step_len]
            buf_id = [x[0] for x in _chunk]
            buf = [x[1] for x in _chunk]
            next_i += step_len
            input_ids = torch.stack([torch.from_numpy(np.array(tokenizer.convert_tokens_to_ids(x))) for x in buf])
            input_ids = input_ids.to('cuda')

            # NOTE(review): max_length is derived from the longest prompt of the
            # look-ahead window (prompt_tokens_num[-1]), not from the prompts actually
            # in this chunk, and --max_tgt_length is not used — confirm this is intended.
            outputs = model.generate(input_ids=input_ids,
                                     min_length=args.min_tgt_length,
                                     max_length=prompt_tokens_num[-1] + 10,
                                     num_beams=args.beam_size,
                                     no_repeat_ngram_size=args.ngram_size,
                                     length_penalty=args.length_penalty)
            prompt = [tokenizer.convert_tokens_to_string(x) for x in buf]
            # Keep the prompt, then append only the newly generated text up to the
            # first end-of-text marker.
            output_sequences = [prompt[i] + tokenizer.decode(o).split(prompt[i])[-1].split('<|endoftext|>')[0].rstrip()
                                for i, o in enumerate(outputs)]

            for i in range(len(buf)):
                output_lines[buf_id[i]] = output_sequences[i].replace("\n", "")
            pbar.update(len(buf))

    with open(args.output_file, "w", encoding="utf-8") as fout:
        for l in output_lines:
            fout.write(l)
            fout.write("\n\n")


In [11]:
logger = logging.getLogger(__name__)


def main_neuro():
    """Decode every prompt of the input file with constraint-aware beam search.

    Prompts and their lexical constraints (one JSON line per prompt) are
    tokenized, constraint states are advanced over the prompt tokens, and
    ``generate`` is called batch by batch. One line per hypothesis is appended
    to the output file and flushed immediately.

    If the output file already exists, the run resumes: as many prompts as the
    file has lines are skipped and new results are appended.

    All settings come from the command line (see the argparse definitions);
    the model is moved to the ``'cuda'`` device.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_name", type=str, help="pretrained language model to use", default='gpt2-medium')
    parser.add_argument("--output_file", type=str, help="output file", default="/kaggle/working/neuro_few_shot_gpt2-medium.txt")
    parser.add_argument("--constraint_file", type=str, help="constraint file", default="/kaggle/input/constr/constraints_single_clause.json")
    parser.add_argument("--input_path", type=str, help="initialization of decoding", default='/kaggle/input/commonsenseqa-no-neuro-fs/nl_questions_fs.txt')

    parser.add_argument('--batch_size', type=int, default=1,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=8,
                        help="Beam size for searching")
    parser.add_argument('--max_tgt_length', type=int, default=32,
                        help="maximum length of decoded sentences")
    parser.add_argument('--min_tgt_length', type=int, default=0,
                        help="minimum length of decoded sentences")
    parser.add_argument('--ngram_size', type=int, default=3,
                        help='all ngrams can only occur once')
    parser.add_argument('--length_penalty', type=float, default=0.2,
                        help="length penalty for beam search")

    parser.add_argument('--prune_factor', type=int, default=500000,
                        help="fraction of candidates to keep based on score")
    parser.add_argument('--sat_tolerance', type=int, default=2,
                        help="minimum satisfied clause of valid candidates")
    parser.add_argument('--beta', type=float, default=1.25,
                        help="reward factor for in progress constraint")
    parser.add_argument('--early_stop', type=float, default=10,
                        help="optional early stop if all constraints are satisfied")
    # Not an option of this script: it only absorbs the `-f <connection file>`
    # argument present in sys.argv when running inside a Jupyter kernel.
    parser.add_argument('-f', type=str, default=None,
                        help="ignored; absorbs the kernel connection file passed by Jupyter")

    args = parser.parse_args()
    print(args)

    print(f"Decoding with: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelWithLMHead.from_pretrained(args.model_name)

    torch.cuda.empty_cache()
    model.eval()
    model = model.to('cuda')

    # A hypothesis ends on the tokenizer's eos token or on a period token
    # (looked up both as '.' and as 'Ġ.').
    period_id = [tokenizer.convert_tokens_to_ids('.')]
    period_id.append(tokenizer.convert_tokens_to_ids('Ġ.'))
    eos_ids = [tokenizer.eos_token_id] + period_id
    # NOTE(review): '<pad>' may not exist in the chosen tokenizer's vocabulary —
    # confirm which id this lookup resolves to.
    PAD_ID = tokenizer.convert_tokens_to_ids('<pad>')
    bad_token = [':', "'", '-', '_', '@', 'Ċ', 'Ġ:', 'Ġwho', "'s"]
    bad_words_ids = [tokenizer.convert_tokens_to_ids([t]) for t in bad_token]

    with open(args.input_path) as fin:
        input_lines = [line.strip() for line in fin.read().splitlines()]

    def read_constraints(file_name):
        """Read one JSON list of concepts per line; keep lowercase words, prefixed with a space."""
        cons_list = []
        with open(file_name, 'r') as f:
            for line in f:
                cons = []
                for concept in json.loads(line):
                    cons.append([f' {c}' for c in concept if c.islower()])
                cons_list.append(cons)
        return cons_list

    constraints_list = read_constraints(args.constraint_file)
    input_lines = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(x)) for x in input_lines]
    constraints_list = tokenize_constraints(tokenizer, constraints_list)

    if path.exists(args.output_file):
        # Resume: every line already written corresponds to one processed prompt.
        # The file is read with the encoding it is written with below.
        with open(args.output_file, 'r', encoding="utf-8") as previous_output:
            count = sum(1 for _ in previous_output)
        input_lines = input_lines[count:]
        constraints_list = constraints_list[count:]
        open_mode = "a"
    else:
        open_mode = "w"
    total_batch = math.ceil(len(input_lines) / args.batch_size)

    next_i = 0

    # The output file is managed by the `with` block so it is closed even when
    # decoding fails part-way through.
    with Path(args.output_file).open(open_mode, encoding="utf-8") as fout, tqdm(total=total_batch) as pbar:
        while next_i < len(input_lines):
            _chunk = input_lines[next_i:next_i + args.batch_size]
            constraints = init_batch(raw_constraints=constraints_list[next_i:next_i + args.batch_size],
                                     beam_size=args.beam_size,
                                     eos_id=eos_ids)
            buf = _chunk
            next_i += args.batch_size

            # Right-pad every prompt of the batch to the longest one.
            max_len = max([len(x) for x in buf])
            buf = [x + [PAD_ID] * (max_len - len(x)) for x in buf]

            input_ids = torch.stack([torch.from_numpy(np.array(x)) for x in buf])
            input_ids = input_ids.to('cuda')
            attention_mask = (~torch.eq(input_ids, PAD_ID)).int()
            attention_mask = attention_mask.to('cuda')

            free_gpu_cache()
            # Feed the (unpadded) prompt tokens through each constraint state.
            # Presumably init_batch returns beam_size states per prompt, hence the
            # `j // beam_size` lookup — verify against init_batch.
            advanced_constraints = []
            for j, init_cons in enumerate(constraints):
                adv_cons = init_cons
                for token in _chunk[j // args.beam_size]:
                    adv_cons = adv_cons.advance(token)
                advanced_constraints.append(adv_cons)
            free_gpu_cache()
            outputs, _, _ = generate(self=model,
                                     input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     pad_token_id=PAD_ID,
                                     bad_words_ids=bad_words_ids,
                                     min_length=args.min_tgt_length,
                                     max_length=max_len + 10,
                                     num_beams=args.beam_size,
                                     no_repeat_ngram_size=args.ngram_size,
                                     length_penalty=args.length_penalty,
                                     constraints=advanced_constraints,
                                     prune_factor=args.prune_factor,
                                     sat_tolerance=args.sat_tolerance,
                                     beta=args.beta,
                                     early_stop=args.early_stop,
                                     )

            prompt = [tokenizer.decode(x) for x in buf]
            # Keep the prompt, then append only the newly generated text up to the
            # first end-of-text marker.
            output_sequences = [prompt[i] + tokenizer.decode(o).split(prompt[i])[-1].split('<|endoftext|>')[0].rstrip()
                                for i, o in enumerate(outputs)]
            for hypothesis in output_sequences:
                fout.write(hypothesis.strip().replace('<|endoftext|>', '') + "\n")
                # Flush per line so an interrupted run can be resumed from the file.
                fout.flush()

            pbar.update(1)

In [None]:
# Driver cell: runs the constraint-aware decoding defined in main_neuro().
# To produce the plain beam-search baseline instead, enable the call below.
# main_no_neuro()

main_neuro()

Namespace(model_name='gpt2-medium', output_file='/kaggle/working/neuro_few_shot_gpt2-medium.txt', constraint_file='/kaggle/input/constr/constraints_single_clause.json', input_path='/kaggle/input/commonsenseqa-no-neuro-fs/nl_questions_fs.txt', batch_size=1, beam_size=8, max_tgt_length=32, min_tgt_length=0, ngram_size=3, length_penalty=0.2, prune_factor=500000, sat_tolerance=2, beta=1.25, early_stop=10, f='/root/.local/share/jupyter/runtime/kernel-df34d823-5073-462a-84e2-4763e8c56939.json')
Decoding with: gpt2-medium


 16%|█▌        | 192/1221 [39:05<3:25:16, 11.97s/it]

In [10]:
# Clear output folder
import os

def remove_folder_contents(folder):
    """Delete everything inside ``folder`` while keeping ``folder`` itself.

    Regular files and symbolic links are unlinked; real sub-directories are
    emptied recursively and then removed. Symbolic links are never followed, so
    nothing outside ``folder`` is touched.

    Removal is best-effort: an error on one entry is printed and the remaining
    entries are still processed.

    Args:
        folder: path of the directory to empty.
    """
    for the_file in os.listdir(folder):
        file_path = os.path.join(folder, the_file)
        try:
            # Check for links first: os.path.isdir() follows symlinks, which would
            # otherwise wipe the link target's contents and leave the link behind.
            if os.path.islink(file_path) or os.path.isfile(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                remove_folder_contents(file_path)
                os.rmdir(file_path)
        except Exception as e:
            print(e)

# folder_path = '/kaggle/working'
# remove_folder_contents(folder_path)