In [1]:
!pip install transformers
!pip install datasets

Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 5.5 MB/s 
Collecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 24.3 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 35.6 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 31.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 4.6 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Foun

In [2]:
# Varint encoding and decoding functions

import sys

# Adopted from
# https://github.com/bright-tools/varints

# Largest value that each encoded length can represent. The one/two/three-byte
# limits follow from the header arithmetic in encode_int/decode_val; from four
# bytes upwards the limit is the largest unsigned integer that fits in the
# payload bytes following the header byte.
ONE_BYTE_LIMIT = 240
TWO_BYTE_LIMIT = 2287
THREE_BYTE_LIMIT = 67823

FOUR_BYTE_LIMIT = (1 << 24) - 1
FIVE_BYTE_LIMIT = (1 << 32) - 1
SIX_BYTE_LIMIT = (1 << 40) - 1
SEVEN_BYTE_LIMIT = (1 << 48) - 1
EIGHT_BYTE_LIMIT = (1 << 56) - 1
NINE_BYTE_LIMIT = (1 << 64) - 1

# First-byte values announcing an encoding of the given total length.
THREE_BYTE_HEADER = 249
FOUR_BYTE_HEADER = 250
FIVE_BYTE_HEADER = 251
SIX_BYTE_HEADER = 252
SEVEN_BYTE_HEADER = 253
EIGHT_BYTE_HEADER = 254
NINE_BYTE_HEADER = 255

# Number of distinct values in one byte / two bytes.
BYTE_VALS = 256
SHORT_VALS = 65536

# Bucket index + BUCKET_OFFSET + 1 is the number of payload bytes of a bucket.
BUCKET_OFFSET = 2

# Inclusive range of integers the codec can represent.
minint = 0
maxint = NINE_BYTE_LIMIT

# Encodings of four or more bytes, ordered by increasing capacity: 'limit' is
# the largest value the bucket holds and 'header' is its leading byte.
buckets = [
    dict( limit=FOUR_BYTE_LIMIT, header=FOUR_BYTE_HEADER ),
    dict( limit=FIVE_BYTE_LIMIT, header=FIVE_BYTE_HEADER ),
    dict( limit=SIX_BYTE_LIMIT, header=SIX_BYTE_HEADER ),
    dict( limit=SEVEN_BYTE_LIMIT, header=SEVEN_BYTE_HEADER ),
    dict( limit=EIGHT_BYTE_LIMIT, header=EIGHT_BYTE_HEADER ),
    dict( limit=NINE_BYTE_LIMIT, header=NINE_BYTE_HEADER ),
]


def writeToFile(payload, filename):
    """Encode `payload` with varint_encode and write the bytes to `filename`.

    The payload is encoded *before* the file is opened, so a payload that
    cannot be encoded leaves an existing file untouched instead of
    truncating it.

    Raises:
        ValueError: propagated from the encoder (negative or too-large value).
        TypeError: if the encoder returns None, i.e. `payload` is of a type
            it does not handle.
    """
    encoded = varint_encode(payload)
    if encoded is None:
        # Writing None would raise TypeError anyway, but only after the open()
        # below had already emptied the file.
        raise TypeError("Payload could not be varint-encoded")
    with open(filename, "wb") as f:
        f.write(encoded)

def readFromFile(filename):
    """Read the whole of `filename` in binary mode and return varint_decode of its contents."""
    with open(filename, "rb") as in_file:
        raw = in_file.read()
    return varint_decode(raw)

def varint_encode( num ):
    """Encode `num` by delegating to generic_encode with the module-level `funcs` table.

    Returns whatever generic_encode returns for `num`.
    """
    return generic_encode( num, funcs )

def encode_int( num ):
    """Encode one non-negative integer as a variable-length byte string.

    Layout, by magnitude of `num`:
      * up to ONE_BYTE_LIMIT: the value itself in a single byte;
      * up to TWO_BYTE_LIMIT: two bytes, the first being ONE_BYTE_LIMIT+1 plus
        the high part of (num - ONE_BYTE_LIMIT), the second the low part;
      * up to THREE_BYTE_LIMIT: THREE_BYTE_HEADER followed by the high and low
        byte of (num - (TWO_BYTE_LIMIT+1));
      * otherwise: the header of the first entry in `buckets` whose limit is
        not exceeded, followed by `num` written most-significant byte first.

    Raises:
        ValueError: if `num` is negative or exceeds every bucket limit.
    """
    if num < 0:
        raise ValueError("Negative numbers not handled")

    if num <= ONE_BYTE_LIMIT:
        return varint_storage( num )

    if num <= TWO_BYTE_LIMIT:
        high, low = divmod( num - ONE_BYTE_LIMIT, BYTE_VALS )
        return varint_storage( high + ONE_BYTE_LIMIT + 1 ) + varint_storage( low )

    if num <= THREE_BYTE_LIMIT:
        high, low = divmod( num - (TWO_BYTE_LIMIT + 1), BYTE_VALS )
        return (varint_storage( THREE_BYTE_HEADER ) +
                varint_storage( high ) +
                varint_storage( low ))

    # Pick the smallest bucket able to hold the value.
    for index, bucket in enumerate( buckets ):
        if num <= bucket['limit']:
            break
    else:
        raise ValueError("Too large")

    encoded = varint_storage( bucket['header'] )
    # Weight of the most significant payload byte of this bucket.
    place = (bucket['limit'] + 1) // BYTE_VALS
    # Bucket `index` carries index + BUCKET_OFFSET + 1 payload bytes.
    for _ in range( index + BUCKET_OFFSET + 1 ):
        digit, num = divmod( num, place )
        encoded = encoded + varint_storage( digit )
        place = place // BYTE_VALS

    return encoded

def varint_decode( num ):
    """Decode `num` by delegating to generic_decode with the module-level `funcs` table.

    Returns whatever generic_decode returns for `num`.
    """
    return generic_decode( num, funcs )

def decode_val( num ):
    """Decode the single varint found at the start of `num`.

    Returns:
        A tuple (value, bytes_used), where bytes_used is the number of leading
        elements of `num` consumed by this value.

    No length check is performed: indexing past the end of `num` raises
    whatever the indexing operation raises.
    """
    first = store_to_num( num[ 0 ] )

    # Single byte: the byte is the value.
    if first <= ONE_BYTE_LIMIT:
        return (first, 1)

    # Two bytes: the first byte selects a 256-wide block above ONE_BYTE_LIMIT.
    if first < THREE_BYTE_HEADER:
        second = store_to_num( num[ 1 ] )
        value = ONE_BYTE_LIMIT + (BYTE_VALS * (first - (ONE_BYTE_LIMIT + 1))) + second
        return (value, 2)

    # Three bytes: header, then a 16-bit offset above TWO_BYTE_LIMIT.
    if first == THREE_BYTE_HEADER:
        second = store_to_num( num[ 1 ] )
        third = store_to_num( num[ 2 ] )
        return ((TWO_BYTE_LIMIT + 1) + (BYTE_VALS * second) + third, 3)

    # Longer forms: header 250 is followed by 3 payload bytes, 251 by 4, and so
    # on, stored most-significant byte first.
    data_bytes = first - 247
    place = (buckets[data_bytes - 1 - BUCKET_OFFSET]['limit'] + 1) // BYTE_VALS
    value = 0
    for position in range( 1, data_bytes + 1 ):
        value = value + (place * store_to_num( num[ position ] ))
        place = place // BYTE_VALS

    return (value, data_bytes + 1)

# Dispatch table handed to generic_encode / generic_decode: the per-value
# encoder and decoder of this varint flavour.
funcs = dict( decode_val=decode_val,
              encode_int=encode_int )

# Byte-storage helpers, defined per major Python version: Python 3 works on
# `bytes` (whose elements are already ints), Python 2 on `str`.
if sys.version_info >= (3,):
    def empty_varint_storage():
        """Return an empty byte string."""
        return b""
    def varint_storage(b):
        """Return the integer `b` packed into a one-byte bytes object."""
        return bytes([b])
    def store_to_num(b):
        """Return the stored element unchanged."""
        return b
    def num_types():
        """Return the type accepted as an integer by isinstance checks."""
        return int
else:
    def empty_varint_storage():
        """Return an empty string."""
        return str()
    def varint_storage(b):
        """Return the one-character string with ordinal `b`."""
        return chr(b)
    def store_to_num(b):
        """Return the ordinal of the stored character `b`."""
        return ord(b)
    def num_types():
        """Return the types accepted as integers by isinstance checks."""
        return (int, long)

def dump( num ):
    """Print the length of `num`, then one line per element with its numeric value."""
    line = "B: {}"
    print( "Len: {}".format( len(num) ))
    for item in num:
        print( line.format( store_to_num(item) ))

def generic_encode( num, funcs ):
    """Encode a single integer or a list of integers.

    A list is handed to encode_list; a value of one of the num_types() is
    handed to funcs['encode_int']. Any other input yields None.
    """
    if isinstance( num, list ):
        return encode_list( num, funcs )
    if isinstance( num, num_types() ):
        return funcs['encode_int']( num )
    return None

def encode_list( num, funcs ):
    """Encode every value of `num` with funcs['encode_int'] and concatenate the results.

    The pieces are combined with one join instead of repeated `+`, which
    would re-copy the growing result for every value (quadratic in the
    output size). An empty `num` yields empty storage.
    """
    # The lookup stays inside the generator so that, as before, an empty input
    # never touches `funcs`.
    return empty_varint_storage().join( funcs['encode_int']( val ) for val in num )

def generic_decode( num, funcs ):
    """Decode every value packed in `num` using funcs['decode_val'].

    Returns:
        None if `num` is not a str/bytes object or is empty, the bare value
        if exactly one value was decoded, otherwise a list of the values in
        stream order.
    """
    if not isinstance( num, (str, bytes) ):
        return None

    result = None
    offset = 0
    while offset < len( num ):
        value, consumed = funcs['decode_val']( num[offset:] )
        offset = offset + consumed
        if result is None:
            # First value: keep it bare until a second one shows up.
            result = value
            continue
        if isinstance( result, num_types() ):
            result = [result]
        result.append( value )
    return result


In [3]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gzip

# Load the pretrained 'gpt2' language model and move it to the GPU.
# NOTE(review): the device is hard-coded to 'cuda', so this cell needs a CUDA device.
model = GPT2LMHeadModel.from_pretrained('gpt2').to('cuda')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Use the end-of-sequence token as the padding token
# (presumably because the GPT-2 tokenizer ships without one -- TODO confirm).
tokenizer.pad_token = tokenizer.eos_token
# Put the model in evaluation mode; binding the result to `_` keeps the cell output empty.
_ = model.eval()

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/523M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

# Arithmetic Coding (Adapted from Project Nayuki)

In [4]:
# 
# Reference arithmetic coding
# Copyright (c) Project Nayuki
# 
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
# 


# ---- Arithmetic coding core classes ----

# Provides the state and behaviors that arithmetic coding encoders and decoders share.
class ArithmeticCoderBase:
	"""State and range-narrowing logic common to the arithmetic encoder and decoder.

	Subclasses supply shift() and underflow(), which update() calls whenever it
	renormalizes the [low, high] interval.
	"""

	def __init__(self, numbits):
		"""Set up a coder whose 'low' and 'high' are `numbits` bits wide.

		Larger widths permit a larger frequency total and reduce rounding loss
		at the cost of wider integer arithmetic; Python integers are unbounded,
		so there is no upper limit on `numbits`.

		Raises:
			ValueError: if `numbits` is less than 1.
		"""
		if numbits < 1:
			raise ValueError("State size out of range")

		# -- Configuration fields --
		# Width in bits of the 'low' and 'high' state variables.
		self.num_state_bits = numbits
		# 2^num_state_bits: the largest possible range (high + 1 - low).
		self.full_range = 1 << numbits
		# The top bit at width num_state_bits; never zero.
		self.half_range = self.full_range >> 1
		# The second-highest bit; zero when num_state_bits is 1.
		self.quarter_range = self.half_range >> 1
		# Smallest range that can occur between symbols; at least 2.
		self.minimum_range = self.quarter_range + 2
		# Largest frequency-table total that update() accepts.
		self.maximum_total = self.minimum_range
		# num_state_bits ones.
		self.state_mask = self.full_range - 1

		# -- State fields --
		# Lower bound of the current range (conceptually followed by endless 0 bits).
		self.low = 0
		# Upper bound of the current range (conceptually followed by endless 1 bits).
		self.high = self.state_mask


	def update(self, freqs, symbol):
		"""Narrow [low, high] to the sub-range of `symbol` under `freqs`, then renormalize.

		Invariants holding before and after every call (full_range = 2^num_state_bits):
		- 0 <= low <= high < full_range, so both fit in num_state_bits bits.
		- low < full_range/2 <= high: they lie in different halves.
		- low < full_range/4 or high >= 3*full_range/4: they are not both in
		  the middle two quarters.
		- minimum_range <= high - low + 1 <= full_range, which is what bounds
		  the total a frequency table may have.

		Raises:
			AssertionError: if the stored state violates the checks above.
			ValueError: if `symbol` has zero frequency or the table total
				exceeds maximum_total.
		"""
		mask = self.state_mask
		half = self.half_range
		quarter = self.quarter_range

		# State check
		lo, hi = self.low, self.high
		if lo >= hi or (lo & mask) != lo or (hi & mask) != hi:
			raise AssertionError("Low or high out of range")
		span = hi - lo + 1
		if span < self.minimum_range or span > self.full_range:
			raise AssertionError("Range out of range")

		# Frequency table values check
		total = freqs.get_total()
		symlow = freqs.get_low(symbol)
		symhigh = freqs.get_high(symbol)
		if symlow == symhigh:
			raise ValueError("Symbol has zero frequency")
		if total > self.maximum_total:
			raise ValueError("Cannot code symbol because total is too large")

		# Scale the symbol's cumulative interval onto the current range.
		self.low = lo + symlow * span // total
		self.high = lo + symhigh * span // total - 1

		# While both bounds agree on the top bit, hand it to the subclass and shift it out.
		while ((self.low ^ self.high) & half) == 0:
			self.shift()
			self.low = (self.low << 1) & mask
			self.high = ((self.high << 1) & mask) | 1
		# From here low starts with 0 and high starts with 1.

		# While low starts with 01 and high with 10, drop the second-highest bit of both.
		while (self.low & ~self.high & quarter) != 0:
			self.underflow()
			self.low = (self.low << 1) ^ half
			self.high = ((self.high ^ half) << 1) | half | 1


	def shift(self):
		"""Hook invoked when 'low' and 'high' share their top bit; subclasses must override."""
		raise NotImplementedError()


	def underflow(self):
		"""Hook invoked when low=01(...) and high=10(...); subclasses must override."""
		raise NotImplementedError()



# Encodes symbols and writes to an arithmetic-coded bit stream.
class ArithmeticEncoder(ArithmeticCoderBase):
	"""Arithmetic coder that encodes symbols and emits bits to a bit output stream."""

	def __init__(self, numbits, bitout):
		"""Create an encoder with `numbits` state bits that writes to `bitout`."""
		super().__init__(numbits)
		# The underlying bit output stream.
		self.output = bitout
		# Count of underflow bits still owed to the output; can grow without bound.
		self.num_underflow = 0


	def write(self, freqs, symbol):
		"""Encode `symbol` under `freqs`; this may write some bits to the output.

		`freqs` is wrapped in a CheckedFrequencyTable unless it already is one.
		"""
		table = freqs if isinstance(freqs, CheckedFrequencyTable) else CheckedFrequencyTable(freqs)
		self.update(table, symbol)


	def finish(self):
		"""Terminate the code by writing a single 1 bit.

		Must be called once at the end of every encoding run so the output can
		be decoded. The underlying stream is written to but not closed.
		"""
		self.output.write(1)


	def shift(self):
		"""Emit the top bit of 'low', followed by every saved underflow bit inverted."""
		top_bit = self.low >> (self.num_state_bits - 1)
		self.output.write(top_bit)

		inverse = top_bit ^ 1
		for _ in range(self.num_underflow):
			self.output.write(inverse)
		self.num_underflow = 0


	def underflow(self):
		"""Record one more pending underflow bit."""
		self.num_underflow += 1



# Reads from an arithmetic-coded bit stream and decodes symbols.
class ArithmeticDecoder(ArithmeticCoderBase):
	"""Arithmetic coder that reads bits from a bit input stream and decodes symbols."""

	def __init__(self, numbits, bitin):
		"""Create a decoder with `numbits` state bits and preload the code register from `bitin`."""
		super().__init__(numbits)
		# The underlying bit input stream.
		self.input = bitin
		# Buffered raw code bits; kept within [low, high].
		self.code = 0
		for _ in range(self.num_state_bits):
			self.code = (self.code << 1) | self.read_code_bit()


	def read(self, freqs):
		"""Decode and return the next symbol under `freqs`; may consume input bits.

		`freqs` is wrapped in a CheckedFrequencyTable unless it already is one.

		Raises:
			ValueError: if the table total exceeds maximum_total.
			AssertionError: if the code register leaves [low, high].
		"""
		if not isinstance(freqs, CheckedFrequencyTable):
			freqs = CheckedFrequencyTable(freqs)

		# Map the code's position inside the current range onto the table's scale.
		total = freqs.get_total()
		if total > self.maximum_total:
			raise ValueError("Cannot decode symbol because total is too large")
		span = self.high - self.low + 1
		offset = self.code - self.low
		value = ((offset + 1) * total - 1) // span
		assert value * span // total <= offset
		assert 0 <= value < total

		# Bisect for the highest symbol whose cumulative low does not exceed `value`.
		lo_sym, hi_sym = 0, freqs.get_symbol_limit()
		while hi_sym - lo_sym > 1:
			mid = (lo_sym + hi_sym) >> 1
			if freqs.get_low(mid) > value:
				hi_sym = mid
			else:
				lo_sym = mid
		assert lo_sym + 1 == hi_sym

		symbol = lo_sym
		assert freqs.get_low(symbol) * span // total <= offset < freqs.get_high(symbol) * span // total
		self.update(freqs, symbol)
		if self.code < self.low or self.code > self.high:
			raise AssertionError("Code out of range")
		return symbol


	def shift(self):
		"""Drop the top bit of the code register and append the next input bit."""
		self.code = ((self.code << 1) & self.state_mask) | self.read_code_bit()


	def underflow(self):
		"""Keep the top bit, drop the second-highest bit, and append the next input bit."""
		top = self.code & self.half_range
		rest = (self.code << 1) & (self.state_mask >> 1)
		self.code = top | rest | self.read_code_bit()


	def read_code_bit(self):
		"""Return the next input bit; the end of the stream reads as endless 0 bits."""
		bit = self.input.read()
		return 0 if bit == -1 else bit



# ---- Frequency table classes ----

# A table of symbol frequencies. The table holds data for symbols numbered from 0
# to get_symbol_limit()-1. Each symbol has a frequency, which is a non-negative integer.
# Frequency table objects are primarily used for getting cumulative symbol
# frequencies. These objects can be mutable depending on the implementation.
class FrequencyTable:
	"""Abstract interface of a symbol frequency table.

	Symbols are numbered from 0 to get_symbol_limit()-1 and each has a
	non-negative integer frequency. Every method here raises
	NotImplementedError; concrete tables override them.
	"""
	
	def get_symbol_limit(self):
		"""Return the number of symbols in this table, which is a positive number."""
		raise NotImplementedError()
	
	def get(self, symbol):
		"""Return the frequency of `symbol`, which is at least 0."""
		raise NotImplementedError()
	
	def set(self, symbol, freq):
		"""Set the frequency of `symbol` to `freq`, which must be at least 0."""
		raise NotImplementedError()
	
	def increment(self, symbol):
		"""Increment the frequency of `symbol`."""
		raise NotImplementedError()
	
	def get_total(self):
		"""Return the sum of all symbol frequencies.

		The value is at least 0 and equals get_high(get_symbol_limit() - 1).
		"""
		raise NotImplementedError()
	
	def get_low(self, symbol):
		"""Return the sum of the frequencies of all symbols strictly below `symbol`."""
		raise NotImplementedError()
	
	def get_high(self, symbol):
		"""Return the sum of the frequencies of `symbol` and all symbols below it."""
		raise NotImplementedError()



# An immutable frequency table where every symbol has the same frequency of 1.
# Useful as a fallback model when no statistics are available.
class FlatFrequencyTable(FrequencyTable):
	"""Immutable table in which each of `numsyms` symbols has frequency 1."""

	def __init__(self, numsyms):
		"""Create a flat table of `numsyms` symbols.

		Raises:
			ValueError: if `numsyms` is less than 1.
		"""
		if numsyms < 1:
			raise ValueError("Number of symbols must be positive")
		# Total number of symbols, at least 1.
		self.numsymbols = numsyms

	def get_symbol_limit(self):
		"""Return the number of symbols in this table."""
		return self.numsymbols

	def get(self, symbol):
		"""Return 1, the frequency of every valid symbol."""
		self._check_symbol(symbol)
		return 1

	def get_total(self):
		"""Return the sum of all frequencies, which equals the number of symbols."""
		return self.numsymbols

	def get_low(self, symbol):
		"""Return the cumulative frequency strictly below `symbol`, i.e. `symbol` itself."""
		self._check_symbol(symbol)
		return symbol

	def get_high(self, symbol):
		"""Return the cumulative frequency up to and including `symbol`, i.e. `symbol` + 1."""
		self._check_symbol(symbol)
		return symbol + 1

	def _check_symbol(self, symbol):
		"""Raise ValueError unless 0 <= symbol < numsymbols."""
		if not (0 <= symbol < self.numsymbols):
			raise ValueError("Symbol out of range")

	def __str__(self):
		"""Return a short description of the table; the format may change."""
		return "FlatFrequencyTable={}".format(self.numsymbols)

	def set(self, symbol, freq):
		"""Unsupported: this table is immutable."""
		raise NotImplementedError()

	def increment(self, symbol):
		"""Unsupported: this table is immutable."""
		raise NotImplementedError()



# A mutable table of symbol frequencies. The number of symbols cannot be changed
# after construction. The current algorithm for calculating cumulative frequencies
# takes linear time, but there exist faster algorithms such as Fenwick trees.
class SimpleFrequencyTable(FrequencyTable):
	"""Mutable frequency table backed by a plain list.

	The number of symbols is fixed at construction. The cumulative sums are
	rebuilt in linear time on the first lookup after any change.
	"""

	def __init__(self, freqs):
		"""Build the table from a FrequencyTable (copied) or from a sequence of frequencies.

		Raises:
			ValueError: if there are no symbols or any frequency is negative.
		"""
		if isinstance(freqs, FrequencyTable):
			self.frequencies = [freqs.get(i) for i in range(freqs.get_symbol_limit())]
		else:
			# Anything else is treated as a sequence and copied.
			self.frequencies = list(freqs)

		if len(self.frequencies) < 1:
			raise ValueError("At least 1 symbol needed")
		if any(freq < 0 for freq in self.frequencies):
			raise ValueError("Negative frequency")

		# Always equal to the sum of 'frequencies'.
		self.total = sum(self.frequencies)

		# cumulative[i] is the sum of frequencies[0:i]; None means "needs rebuilding".
		self.cumulative = None


	def get_symbol_limit(self):
		"""Return the number of symbols in this table, which is at least 1."""
		return len(self.frequencies)


	def get(self, symbol):
		"""Return the frequency of `symbol`."""
		self._check_symbol(symbol)
		return self.frequencies[symbol]


	def set(self, symbol, freq):
		"""Set the frequency of `symbol` to `freq`, leaving the table unchanged on error.

		Raises:
			ValueError: if `symbol` is out of range or `freq` is negative.
		"""
		self._check_symbol(symbol)
		if freq < 0:
			raise ValueError("Negative frequency")
		others = self.total - self.frequencies[symbol]
		assert others >= 0
		self.total = others + freq
		self.frequencies[symbol] = freq
		self.cumulative = None


	def increment(self, symbol):
		"""Add 1 to the frequency of `symbol`."""
		self._check_symbol(symbol)
		self.total += 1
		self.frequencies[symbol] += 1
		self.cumulative = None


	def get_total(self):
		"""Return the sum of all symbol frequencies."""
		return self.total


	def get_low(self, symbol):
		"""Return the sum of the frequencies of all symbols strictly below `symbol`."""
		self._check_symbol(symbol)
		if self.cumulative is None:
			self._init_cumulative()
		return self.cumulative[symbol]


	def get_high(self, symbol):
		"""Return the sum of the frequencies of `symbol` and all symbols below it."""
		self._check_symbol(symbol)
		if self.cumulative is None:
			self._init_cumulative()
		return self.cumulative[symbol + 1]


	def _init_cumulative(self):
		"""Rebuild the list of cumulative frequencies from 'frequencies'."""
		running = 0
		table = [running]
		for freq in self.frequencies:
			running += freq
			table.append(running)
		assert running == self.total
		self.cumulative = table


	def _check_symbol(self, symbol):
		"""Raise ValueError unless 0 <= symbol < len(frequencies)."""
		if not (0 <= symbol < len(self.frequencies)):
			raise ValueError("Symbol out of range")


	def __str__(self):
		"""Return one "index<TAB>frequency" line per symbol; for debugging, format may change."""
		return "".join("{}\t{}\n".format(i, freq) for (i, freq) in enumerate(self.frequencies))



# A wrapper that checks the preconditions (arguments) and postconditions (return value) of all
# the frequency table methods. Useful for finding faults in a frequency table implementation.
class CheckedFrequencyTable(FrequencyTable):
	"""Wrapper that validates the arguments and results of another frequency table.

	A violated expectation is reported as AssertionError, which makes faults in
	the wrapped implementation visible.
	"""

	def __init__(self, freqtab):
		"""Wrap `freqtab`, the table that actually holds the data."""
		self.freqtable = freqtab


	def get_symbol_limit(self):
		"""Return the wrapped table's symbol limit, insisting that it is positive."""
		limit = self.freqtable.get_symbol_limit()
		if limit <= 0:
			raise AssertionError("Non-positive symbol limit")
		return limit


	def get(self, symbol):
		"""Return the frequency of `symbol`, insisting on a valid symbol and a non-negative result."""
		freq = self.freqtable.get(symbol)
		if not self._is_symbol_in_range(symbol):
			# The wrapped table should have rejected this symbol itself.
			raise AssertionError("ValueError expected")
		if freq < 0:
			raise AssertionError("Negative symbol frequency")
		return freq


	def get_total(self):
		"""Return the wrapped table's total, insisting that it is non-negative."""
		total = self.freqtable.get_total()
		if total < 0:
			raise AssertionError("Negative total frequency")
		return total


	def get_low(self, symbol):
		"""Return the cumulative frequency strictly below `symbol` after range checks."""
		if not self._is_symbol_in_range(symbol):
			# Expected to raise ValueError; reaching the next line is the fault.
			self.freqtable.get_low(symbol)
			raise AssertionError("ValueError expected")
		low, _ = self._checked_bounds(symbol, "Symbol low cumulative frequency out of range")
		return low


	def get_high(self, symbol):
		"""Return the cumulative frequency up to and including `symbol` after range checks."""
		if not self._is_symbol_in_range(symbol):
			# Expected to raise ValueError; reaching the next line is the fault.
			self.freqtable.get_high(symbol)
			raise AssertionError("ValueError expected")
		_, high = self._checked_bounds(symbol, "Symbol high cumulative frequency out of range")
		return high


	def _checked_bounds(self, symbol, message):
		"""Return (low, high) for `symbol`, raising AssertionError(message) unless
		0 <= low <= high <= total holds in the wrapped table."""
		low = self.freqtable.get_low(symbol)
		high = self.freqtable.get_high(symbol)
		if not (0 <= low <= high <= self.freqtable.get_total()):
			raise AssertionError(message)
		return low, high


	def __str__(self):
		"""Return the wrapped table's description, labelled as checked."""
		return "CheckedFrequencyTable (" + str(self.freqtable) + ")"


	def set(self, symbol, freq):
		"""Forward to the wrapped table, then insist the arguments were valid."""
		self.freqtable.set(symbol, freq)
		if freq < 0 or not self._is_symbol_in_range(symbol):
			raise AssertionError("ValueError expected")


	def increment(self, symbol):
		"""Forward to the wrapped table, then insist the symbol was valid."""
		self.freqtable.increment(symbol)
		if not self._is_symbol_in_range(symbol):
			raise AssertionError("ValueError expected")


	def _is_symbol_in_range(self, symbol):
		"""Return whether 0 <= symbol < get_symbol_limit()."""
		return 0 <= symbol < self.get_symbol_limit()



# ---- Bit-oriented I/O streams ----

# A stream of bits that can be read. Because they come from an underlying byte stream,
# the total number of bits is always a multiple of 8. The bits are read in big endian.
class BitInputStream:
	"""Reads single bits, most significant first, from an underlying byte stream."""

	def __init__(self, inp):
		"""Wrap `inp`, a byte stream whose read(1) returns at most one byte."""
		# The underlying byte stream to read from.
		self.input = inp
		# Byte currently being consumed, or -1 once the end of the stream was seen.
		self.currentbyte = 0
		# Bits of the current byte not yet handed out.
		self.numbitsremaining = 0


	def read(self):
		"""Return the next bit (0 or 1), or -1 once the underlying stream is exhausted."""
		if self.currentbyte == -1:
			return -1

		remaining = self.numbitsremaining
		if remaining == 0:
			chunk = self.input.read(1)
			if len(chunk) == 0:
				self.currentbyte = -1
				return -1
			self.currentbyte = chunk[0]
			remaining = 8
		assert remaining > 0

		remaining -= 1
		self.numbitsremaining = remaining
		return (self.currentbyte >> remaining) & 1


	def read_no_eof(self):
		"""Return the next bit (0 or 1), raising EOFError at the end of the stream."""
		bit = self.read()
		if bit == -1:
			raise EOFError()
		return bit


	def close(self):
		"""Close the underlying stream and mark this stream as exhausted."""
		self.input.close()
		self.currentbyte = -1
		self.numbitsremaining = 0



# A stream where bits can be written to. Because they are written to an underlying
# byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits.
# The bits are written in big endian.
class BitOutputStream:
	"""Writes single bits, most significant first, to an underlying byte stream."""

	def __init__(self, out):
		"""Wrap `out`, the byte stream that receives each completed byte."""
		# The underlying byte stream to write to.
		self.output = out
		# Bits gathered so far for the byte under construction.
		self.currentbyte = 0
		# How many bits 'currentbyte' holds; a full byte is flushed immediately.
		self.numbitsfilled = 0


	def write(self, b):
		"""Append the bit `b` and flush a byte to the output once 8 bits are gathered.

		Raises:
			ValueError: if `b` is neither 0 nor 1.
		"""
		if b not in (0, 1):
			raise ValueError("Argument must be 0 or 1")
		self.currentbyte = (self.currentbyte << 1) | b
		self.numbitsfilled += 1
		if self.numbitsfilled != 8:
			return
		self.output.write(bytes((self.currentbyte,)))
		self.currentbyte = 0
		self.numbitsfilled = 0


	def close(self):
		"""Pad the pending byte with 0 bits up to a byte boundary, then close the output."""
		while self.numbitsfilled != 0:
			self.write(0)
		self.output.close()


Compress

In [5]:
# 
# Compression application using adaptive arithmetic coding
# 
# Usage: python adaptive-arithmetic-compress.py InputFile OutputFile
# Then use the corresponding adaptive-arithmetic-decompress.py application to recreate the original input file.
# Note that the application starts with a flat frequency table of 257 symbols (all set to a frequency of 1),
# and updates it after each byte encoded. The corresponding decompressor program also starts with a flat
# frequency table and updates it after each byte decoded. It is by design that the compressor and
# decompressor have synchronized states, so that the data can be decompressed properly.
# 
# Copyright (c) Project Nayuki
# 
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
# 
import contextlib, sys
# import arithmeticcoding


# Command line main application function.
def compress_main(args):
	"""Command-line entry point of the compressor.

	`args` must hold exactly two items, the input and the output file name;
	otherwise the process exits with a usage message. The input file is read
	in binary mode and compressed into the output file through a
	BitOutputStream, which is closed on exit.
	"""
	# Handle command line arguments
	if len(args) != 2:
		sys.exit("Usage: python adaptive-arithmetic-compress.py InputFile OutputFile")
	inputfile, outputfile = args
	
	# Perform file compression
	with open(inputfile, "rb") as inp, \
			contextlib.closing(BitOutputStream(open(outputfile, "wb"))) as bitout:
		compress(inp, bitout)


# The decompression cell further down also defines `main`, which silently
# replaces this one when the notebook runs top to bottom. The compressor
# therefore has its own name; `main` is kept as an alias for existing callers.
main = compress_main


def compress(inp, bitout):
	"""Compress all bytes of `inp` to `bitout` with adaptive arithmetic coding.

	The model starts flat over 257 symbols (the 256 byte values plus symbol
	256 as end-of-file marker) and the count of each byte is incremented after
	it has been encoded. The encoder is finished, but `bitout` is not closed.
	"""
	model_freqs = SimpleFrequencyTable(FlatFrequencyTable(257))
	encoder = ArithmeticEncoder(32, bitout)

	# Read and encode one byte at a time until read() returns nothing.
	chunk = inp.read(1)
	while len(chunk) != 0:
		encoder.write(model_freqs, chunk[0])
		model_freqs.increment(chunk[0])
		chunk = inp.read(1)

	encoder.write(model_freqs, 256)  # EOF
	encoder.finish()  # Flush remaining code bits


# Main launcher
# if __name__ == "__main__":
# 	main(sys.argv[1 : ])


Decompress

In [6]:
# 
# Decompression application using adaptive arithmetic coding
# 
# Usage: python adaptive-arithmetic-decompress.py InputFile OutputFile
# This decompresses files generated by the adaptive-arithmetic-compress.py application.
# 
# Copyright (c) Project Nayuki
# 
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
# 

import sys
# import arithmeticcoding


# Command line main application function.
def decompress_main(args):
	"""Command-line entry point of the decompressor.

	`args` must hold exactly two items, the input and the output file name;
	otherwise the process exits with a usage message. The input file is read
	through a BitInputStream and the decoded bytes are written to the output
	file; both files are closed on exit.
	"""
	# Handle command line arguments
	if len(args) != 2:
		sys.exit("Usage: python adaptive-arithmetic-decompress.py InputFile OutputFile")
	inputfile, outputfile = args
	
	# Perform file decompression
	with open(inputfile, "rb") as inp, open(outputfile, "wb") as out:
		bitin = BitInputStream(inp)
		decompress(bitin, out)


# The compression cell above also defines `main`. The decompressor gets its own
# name so that it stays addressable whichever cell ran last; `main` is kept as
# an alias for existing callers.
main = decompress_main


def decompress(bitin, out):
	"""Decode bytes from `bitin` and write them to `out` until the EOF symbol appears.

	Mirrors compress(): the model starts flat over 257 symbols and the count of
	each byte is incremented after it has been decoded, so both sides evolve
	the same model. Symbol 256 ends the stream and is not written.
	"""
	model_freqs = SimpleFrequencyTable(FlatFrequencyTable(257))
	decoder = ArithmeticDecoder(32, bitin)

	# Decode and write one byte at a time.
	symbol = decoder.read(model_freqs)
	while symbol != 256:  # 256 is the EOF symbol
		out.write(bytes((symbol,)))
		model_freqs.increment(symbol)
		symbol = decoder.read(model_freqs)


# Main launcher
# if __name__ == "__main__":
# 	main(sys.argv[1 : ])


In [7]:
# Size of the token vocabulary (presumably that of the 'gpt2' model loaded above -- TODO confirm).
VOCAB_SIZE = 50257
# Token id used for padding (presumably the id of the tokenizer's eos token,
# which is installed as pad token above -- TODO confirm).
PAD_TOKEN = 50256

In [8]:
# Neural compression functions

def valid_encodings(shifted_inputs, encoded_msgs, sorted_tokens):
  """Check that the rank encoding reproduces the original tokens.

  At each timestep the encoded value is used as an index into that
  timestep's list of sorted tokens; the tokens selected this way must equal
  `shifted_inputs` without its last column.

  Parameters
    shifted_inputs: shape (batch_size, token_len)
    encoded_msgs: shape (batch_size, token_len - 1), integer indices
    sorted_tokens: shape (batch_size, token_len, vocab_size)

  Returns a boolean scalar tensor (the result of torch.all).
  """
  batch_size, token_len, vocab_size = sorted_tokens.size()
  msg_len = token_len - 1

  # Flatten the tensor of sorted tokens to make indexing easier
  # and add offsets to the encoded message to account for this flattening.
  sorted_tokens_flat = sorted_tokens.view(batch_size, -1)
  # Build the offsets on the same device as the encodings instead of a
  # hard-coded 'cuda', so the check also works for CPU tensors.
  step_offsets = torch.arange(0, vocab_size*msg_len, vocab_size, device=encoded_msgs.device)
  encoded_msgs_offset = encoded_msgs + step_offsets
  decoded_msgs_cand = torch.gather(sorted_tokens_flat, 1, encoded_msgs_offset)
  return torch.all(decoded_msgs_cand == shifted_inputs[:, :-1])
  

def trans_encode(tokenized_msgs, attentions, vocab_size):
  """
  Rank-encode a batch of tokenized messages with the global `model`.

  Parameters
    tokenized_msgs: shape (batch_size, msg_len)
    attentions: shape (batch_size, msg_len)
    vocab_size: integer, size of the last axis of the model's logits

  Returns
    encoded_msgs: shape (batch_size, msg_len). Column 0 is the first token id
      of each message; column j >= 1 is the position of token j in the list of
      token ids sorted by descending logit at timestep j - 1.
    logits_arr: shape (batch_size, msg_len, vocab_size), the logits used for
      the ranking (returned for debugging).

  Side effects
    Switches `model` to eval mode.
  """
  # Work on whatever device the inputs live on instead of assuming 'cuda'.
  device = tokenized_msgs.device
  batch_size, msgs_len = tokenized_msgs.size()

  # Encode
  model.eval()
  with torch.no_grad():
    # In theory, I should be able to avoid the loop because the transformer
    # automatically masks the input. But in practice, this causes the logit
    # outputs to differ slightly between the encoder and decoder
    #
    # Allocate directly on the target device; this buffer is large
    # (batch_size * msgs_len * vocab_size floats), so a host allocation
    # followed by a copy is wasteful.
    logits_arr = torch.zeros(batch_size, msgs_len, vocab_size, device=device)
    for i in range(msgs_len):
      msgs_slice = tokenized_msgs[:,:i+1]
      attentions_slice = attentions[:,:i+1]
      logits = model(msgs_slice, attention_mask=attentions_slice).logits
      logits_arr[:, i] = logits[:, i]

  # Sort the indices of the logits in descending order of logit value.
  # This means that the model's top predicted token is the first
  # element in the sorted list, the second highest predicted token is the
  # second element, and so on.
  #
  # Once we have this list of tokens ordered by their probability
  # we can find the ground-truth token in this list, and save its index
  # as the encoding of the token.
  #
  # Roll along the sequence axis only, so that one message's tokens never
  # wrap into a neighbouring row of the batch.
  shifted_inputs = torch.roll(tokenized_msgs, -1, dims=1) # Shift inputs to line up with output
  _, sorted_tokens = torch.sort(logits_arr, dim=2, descending=True, stable=True)
  shifted_inputs_reshaped = shifted_inputs.unsqueeze(-1)
  # nonzero() lists matches in row-major order, one per (batch, timestep), so
  # column 2 (the vocab-axis index) reshapes back to (batch_size, msgs_len).
  encoded_msgs = (sorted_tokens == shifted_inputs_reshaped).nonzero()[:,2].reshape(batch_size, -1)
  encoded_msgs = encoded_msgs[:, :-1] # Discard the last index because it overflows the original message
  assert valid_encodings(shifted_inputs, encoded_msgs, sorted_tokens)

  # We need to include the first token as part of the encoded message so that we
  # can bootstrap generation
  encoded_msgs = torch.cat((tokenized_msgs[:,:1], encoded_msgs), dim=1)

  return encoded_msgs, logits_arr # Logits for debugging

def trans_decode(encoded_msgs, vocab_size):
  """
  Reconstruct token ids from rank-encoded messages using the global `model`.

  Parameters
    encoded_msgs: shape (batch_size, msg_len + 1). Column 0 holds the first
      token id of each message; every later column holds the index of the
      next token within the model's descending-logit ordering.
    vocab_size: integer, size of the last axis of the model's logits

  Returns
    decoded_msgs: shape (batch_size, msg_len + 1), the reconstructed token ids
    logits_arr: shape (batch_size, msg_len, vocab_size), the logits used at
      each step (returned for debugging)

  Side effects
    Switches `model` to eval mode.
  """
  # Work on whatever device the encoded messages live on.
  device = encoded_msgs.device

  # The encoder ranks tokens with the model in eval mode; decoding has to use
  # the same mode or the rankings can disagree.
  model.eval()
  with torch.no_grad():
    # The first value in the encoded message
    # is the first token of the original message
    first_tokens = encoded_msgs[:, :1]
    encoded_msgs = encoded_msgs[:,1:]

    batch_size, msg_len = encoded_msgs.size()
    logits_arr = torch.zeros(batch_size, msg_len, vocab_size, device=device) # For debugging
    decoded_msgs = first_tokens
    for i in range(msg_len):
      # Re-run the model on everything decoded so far; position i is the
      # newest token, so its logits rank the candidates for token i + 1.
      logits = model(decoded_msgs).logits
      logits_arr[:,i] = logits[:,i] # For debugging
      _, indices = torch.sort(logits[:,i,:], dim=1, descending=True, stable=True)
      decoded_tokens = torch.gather(indices, 1, encoded_msgs[:,i:i+1])
      decoded_msgs = torch.cat((decoded_msgs, decoded_tokens), dim=1)
  return decoded_msgs, logits_arr # Logits for debugging

def verify_msgs(decoded_msgs, original_msgs, attentions):
  """
  Check that the decoded messages equal the originals wherever attention is on.

  Parameters
    decoded_msgs: integer tensor of decoded token ids
    original_msgs: integer tensor of the original token ids, same shape
    attentions: tensor of the same shape; non-zero marks real message tokens

  Returns
    0-dim bool tensor, True when the masked decoded messages equal
    original_msgs everywhere.
  """
  # Build both helper tensors on the decoded tensor's own device, avoiding a
  # round-trip through host memory and any assumption about the accelerator.
  keep_mask = attentions.to(device=decoded_msgs.device, dtype=torch.bool)
  pad_fill = torch.full_like(decoded_msgs, PAD_TOKEN)
  # We do this masking because the decompressor will spit out garbage output
  # after the end of a message but we don't care about this because we can identify
  # end-of-message by looking for the first padding token.
  decoded_msgs_cleaned = torch.where(keep_mask, decoded_msgs, pad_fill)
  return torch.all(decoded_msgs_cleaned == original_msgs)

In [15]:
# Evaluation functions
import numpy as np
import os
def get_compressed_size(encoded_msgs, attentions, messages):
  """
  Measure how many bytes each message takes under several encodings.

  Parameters
    encoded_msgs: tensor of shape (batch_size, msg_len) holding the encoded
      integers of each message
    attentions: tensor of shape (batch_size, msg_len); the first 0 in a row
      marks where padding starts
    messages: sequence of batch_size strings, the original message texts

  Returns
    A list with one tuple per message:
    (og_size, gzip_encoding_size, varint_encoding_size, arith_encoding_size).
    og_size and gzip_encoding_size are measured on the UTF-8 bytes of
    messages[i]; the other two are measured on the row of encoded_msgs.
  """
  import tempfile  # local import so the function does not depend on another cell

  batch_size = encoded_msgs.size()[0]
  sizes = []
  # The arithmetic coder works on files; keep its scratch files in a
  # temporary directory so nothing is left behind in the working directory.
  with tempfile.TemporaryDirectory() as scratch_dir:
    input_path = os.path.join(scratch_dir, 'input.txt')
    output_path = os.path.join(scratch_dir, 'output.txt')

    for i in range(batch_size):
      # Get end of message by looking at the attentions
      attention = attentions[i].tolist()
      if 0 in attention:
        end = attention.index(0)
        # Keep one entry past the last attended position.
        # NOTE(review): presumably so the first padding token is stored as the
        # end-of-message marker — confirm.
        encoding = encoded_msgs[i][:end+1].tolist()
      else:
        # Entire encoding represents a message. No padding
        encoding = encoded_msgs[i].tolist()

      # VARINT ENCODING
      binary_arr = varint_encode(encoding)
      varint_encoding_size = len(binary_arr)

      # ARITH ENCODING
      # The coder is fed the textual list, e.g. "[7,2617,0]", with spaces removed.
      encoding = str(encoding).replace(" ", "")
      with open(input_path, "w") as f:
        f.write(encoding)
      # Perform arithmetic encoding file compression
      with open(input_path, "rb") as inp, \
          contextlib.closing(BitOutputStream(open(output_path, "wb"))) as bitout:
        compress(inp, bitout)
      arith_encoding_size = os.path.getsize(output_path)

      # Baselines measured on the raw message text
      orig_msg_bytes = bytes(messages[i], 'utf-8')
      og_size = len(orig_msg_bytes)
      gzip_encoding_size = len(gzip.compress(orig_msg_bytes, compresslevel=9))

      sizes.append((og_size, gzip_encoding_size, varint_encoding_size, arith_encoding_size))
  return sizes
  #   binary_arr = varint_encode(encoding)
  #   trans_encoding_size = len(binary_arr)

  #   orig_msg_bytes = bytes(messages[i], 'utf-8')
  #   gzip_encoding_size = len(gzip.compress(orig_msg_bytes, compresslevel=9))
  #   sizes.append((trans_encoding_size, gzip_encoding_size))
  # return sizes

In [10]:
from datasets import Dataset
import json
import gdown

# Fetch the article file from Google Drive into the working directory.
url = 'https://drive.google.com/uc?id=1sAgDtEj-UjJECfTF6xfiWFk7lrTX7yoV'
filename = "articles_1000.json"
gdown.download(url, filename, quiet=False)
# The JSON is loaded as a dict and handed to Dataset.from_dict; the tokenization
# cell reads its 'article' column.
with open(filename, 'r') as f:
    data = json.load(f)
dataset1000 = Dataset.from_dict(data)

Downloading...
From: https://drive.google.com/uc?id=1sAgDtEj-UjJECfTF6xfiWFk7lrTX7yoV
To: /content/articles_1000.json
100%|██████████| 3.77M/3.77M [00:00<00:00, 84.1MB/s]


In [11]:
# Tokenize every article, padding or truncating each one to exactly 512 token ids.
# NOTE(review): articles longer than 512 tokens are truncated here, but
# get_compressed_size measures og_size and the gzip baseline on the full article
# text, so for long articles the neural sizes and the gzip sizes do not cover
# the same content — confirm this is intended before comparing them.
d1000 = dataset1000.map(
    lambda example: tokenizer(example['article'], return_tensors="np", padding="max_length", truncation=True, max_length=512),
    batched=True,
    batch_size=16
)
# Expose only the model inputs, as torch tensors, when the dataset is iterated.
d1000.set_format(type='torch', columns=['input_ids', 'attention_mask'])

  0%|          | 0/63 [00:00<?, ?ba/s]

In [13]:
from torch.utils.data import DataLoader
BATCH_SIZE = 8

# Two loaders over the same dataset: one yields the tokenized tensors, the other
# the raw article strings. The evaluation loop zips them together, so both must
# keep the same batch size and the same (unshuffled) order.
eval_dataloader_1000 = DataLoader(d1000, batch_size=BATCH_SIZE)
article_dataloader_1000 = DataLoader(d1000['article'], batch_size=BATCH_SIZE)

In [None]:
# One entry per batch; each entry is the list of per-article size tuples
# returned by get_compressed_size.
stats_1000 = []
# Index of the first article in the current batch (used only for the progress print).
i = 0
# for batch, articles in zip(eval_dataloader_1000, article_dataloader_1000):
#   compressed_sizes = get_compressed_size(batch['input_ids'], batch['attention_mask'], articles)
#   print(i, compressed_sizes)
#   stats_1000.append(compressed_sizes)
#   i += BATCH_SIZE

for batch, articles in zip(eval_dataloader_1000, article_dataloader_1000):
    # Rank-encode the batch on the GPU; the returned logits are not used here.
    encoded_msgs, encoder_logits = trans_encode(batch['input_ids'].to('cuda'), batch['attention_mask'].to('cuda'), VOCAB_SIZE)
    # Size of each article under the different encodings (see get_compressed_size).
    compressed_sizes = get_compressed_size(encoded_msgs, batch['attention_mask'], articles)
    print(i, compressed_sizes)
    stats_1000.append(compressed_sizes)
    i += BATCH_SIZE

0 [(2772, 1505, 564, 547), (3157, 1537, 545, 516), (5709, 2680, 543, 511), (2581, 1331, 535, 502), (4770, 2247, 539, 509), (3017, 1527, 551, 526), (3453, 1684, 545, 545), (5449, 2576, 536, 511)]
8 [(941, 555, 221, 238), (7322, 3211, 558, 526), (3807, 1800, 552, 523), (4697, 2196, 546, 539), (1626, 849, 346, 346), (3484, 1390, 533, 484), (10349, 4637, 541, 510), (10310, 4571, 561, 540)]
16 [(4686, 2228, 555, 543), (10333, 4121, 555, 514), (2046, 1055, 444, 446), (8631, 3936, 555, 531), (1899, 958, 409, 405), (4720, 2386, 563, 562), (1658, 878, 338, 348), (6426, 3029, 537, 496)]
24 [(3325, 1609, 562, 539), (1484, 789, 341, 353), (2965, 1453, 531, 500), (3361, 1671, 549, 547), (4431, 2205, 577, 562), (5827, 2757, 529, 502), (2658, 1300, 546, 531), (5528, 2601, 554, 552)]


In [None]:
import json

# Persist the collected size statistics in the working directory.
with open('stats1000.json', 'w') as stats_file:
    json.dump(stats_1000, stats_file)

Old

In [10]:
# from datasets import load_dataset, DatasetDict
# news_dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0")
# # Remove unneeded columns, just keep "article"
# news_dataset = news_dataset.remove_columns("id")
# news_dataset = news_dataset.remove_columns("highlights")

Downloading:   0%|          | 0.00/9.27k [00:00<?, ?B/s]

Downloading and preparing dataset cnn_dailymail/3.0.0 to /root/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f...


  0%|          | 0/5 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/572k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/661k [00:00<?, ?B/s]

  0%|          | 0/5 [00:00<?, ?it/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset cnn_dailymail downloaded and prepared to /root/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [11]:
# dev_set = news_dataset['validation'].select(range(24))
# tokenized_dev_set = dev_set.map(
#     lambda example: tokenizer(example['article'], return_tensors="np", padding='max_length', truncation=True, max_length=512),
#     batched=True,
#     batch_size=16
# )
# tokenized_dev_set.set_format(type='torch', columns=['input_ids', 'attention_mask'])

# # Filter to only articles whose tokenized length is less than max_length (512)
# # because we want to be able to compare the gzipped full article
# # to the neurally compressed full article
# # Compressing truncated articles would mean that the gzipped version
# # would compress the untruncated article but the neural compressor
# # would only compress the truncated version :(
# tokenized_dev_set = tokenized_dev_set.filter(
#     lambda example: not torch.all(example['attention_mask'] == 1)
# )
# tokenized_dev_set

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Dataset({
    features: ['article', 'input_ids', 'attention_mask'],
    num_rows: 9
})

In [12]:
# from torch.utils.data import DataLoader
# BATCH_SIZE = 8
# eval_dataloader = DataLoader(tokenized_dev_set, batch_size=BATCH_SIZE)
# article_dataloader = DataLoader(tokenized_dev_set['article'], batch_size=BATCH_SIZE)

In [None]:
# torch.cuda.empty_cache()

In [13]:
# import json
# import os
# i = 0
# stats = []
# for batch, articles in zip(eval_dataloader, article_dataloader):
#     encoded_msgs, encoder_logits = trans_encode(batch['input_ids'].to('cuda'), batch['attention_mask'].to('cuda'), VOCAB_SIZE)
#     compressed_sizes = get_compressed_size(encoded_msgs, batch['attention_mask'], articles)
#     print(i, compressed_sizes)
#     stats.append(compressed_sizes)
#     with open('stats.json', 'w') as stats_file:
#         json.dump(stats, stats_file)
#     i += BATCH_SIZE
# # batch_size = 1 -> 6 seconds
# # batch_size  = 2 -> 10 seconds
# # batch_size = 4 -> 16 seconds
# # batch_Size = 8 -> 30 seconds

encoding:  [7,2617,0,1624,33,1,0,0,18,9,40,1,4381,4,1,1,0,4,7,0,8,11,1,0,1,8,0,5,159,116,0,0,0,0,0,1,3,0,3,0,3,0,0,14,32,13,0,0,0,2,1,4,0,0,2,0,7,0,0,105,0,0,55,286,150,5,0,1,1805,0,0,2,2,2,4,1,0,2,10,1,2,1,2,9,0,0,0,0,1,2,0,0,1,0,0,0,0,0,16,1,1,0,2,0,0,0,1,0,5,44,0,0,1,120,1,2,0,9,0,5,6,5,2,6,0,0,2,4,0,0,4,0,41,1,0,0,44,1,0,218,0,0,23,0,0,0,13,0,0,23,0,0,0,0,11,0,0,4,2,0,0,0,0,0,1,410,0,0,4,11,0,0,0,0,1,0,0,0,0,205,13,7,2,2,0,0,0,0,199,0,0,0,12,5,2,0,2,1,0,718,714,1,0,0,1164,0,17,996,3,0,0,713,0,0,683,1,1,2,2,0,0,0,0,0,7,7,0,0,47,1,102,4,1,39,0,32,5,13,7,90,2,0,109,0,0,0,1,111,0,1,19,4,1,45,0,0,0,0,0,0,39,550,0,0,24,0,0,0,0,27,2,13,0,3,2,1,0,0,3]
encoding:  [7,2617,0,1624,254,19,22,0,1,98,2,0,2,5,0,614,5,329,51,0,85,1,5,1,0,100,18,1,62,0,54,1,25,2,2,169,0,35,0,1,0,0,24,4,5,0,0,1,0,0,7,3,357,0,1,0,2,960,455,0,5,19,4,8,0,0,0,0,0,1,0,0,1,0,0,1,1,0,11,0,2,1,75,16,0,0,30,0,10,0,0,0,5,2,207,0,1,0,43,0,0,1,3,0,0,6,1,0,16,0,5,2,0,0,3,2,67,0,9,0,1,51,0,76,1285,36,1,0,7,0,20,1,3,79,2,0,1,2,0,1,

In [14]:
# from datasets import Dataset
# import json
# import gdown

# url = 'https://drive.google.com/uc?id=1sAgDtEj-UjJECfTF6xfiWFk7lrTX7yoV'
# filename = "articles_1000.json"
# gdown.download(url, filename, quiet=False)
# with open(filename, 'r') as f:
#     data = json.load(f)
# dataset1000 = Dataset.from_dict(data)
# dataset300 = dataset1000.select(range(300))
# dataset500 = dataset1000.select(range(500))

Downloading...
From: https://drive.google.com/uc?id=1sAgDtEj-UjJECfTF6xfiWFk7lrTX7yoV
To: /content/articles_1000.json
100%|██████████| 3.77M/3.77M [00:00<00:00, 196MB/s]


In [15]:
# dataset300

Dataset({
    features: ['article'],
    num_rows: 300
})

In [None]:
# import gzip

In [None]:
# %%timeit
# for article in tokenized_dev_set['article']:
#   orig_msg_bytes = bytes(article, 'utf-8')
#   gzip.compress(orig_msg_bytes, compresslevel=9)


Dataset({
    features: ['article', 'input_ids', 'attention_mask'],
    num_rows: 301
})

In [None]:
# from torch.utils.data import DataLoader

# eval_dataloader = DataLoader(tokenized_dev_set, batch_size=2)


In [None]:
# for batch in eval_dataloader:
#   encoded_msgs, encoder_logits = trans_encode(batch['input_ids'].to('cuda'), batch['attention_mask'].to('cuda'), VOCAB_SIZE)
#   break


In [None]:
# decoded_msgs, logits_arr = trans_decode(encoded_msgs, VOCAB_SIZE)


RuntimeError: ignored

In [None]:
# verify_msgs(decoded_msgs, batch['input_ids'].to('cuda'), batch['attention_mask'].to('cuda'))

tensor(True, device='cuda:0')

In [None]:
# messages = [" But if you are preparing data and doing cat in each iteration, it gets really slow when the tensor you are generating gets very large. My solution was to cat into", 
#             "msg 2 baby", 
#             "The Boat Race 2021 comprised two side-by-side rowing races that took place on 4 April. The Boat Race is contested annually between crews from the universities of Oxford and Cambridge. Traditionally held on the Championship Course in London, the 2021 race instead took place on the River Great Ouse near Ely (course map pictured). This was the 75th women's race and the 166th men's race;"]

# tokenized = tokenizer(messages, return_tensors="pt", padding="longest", truncation=True, max_length=1024)
# attentions = tokenized.attention_mask.to('cuda')
# # sample_inputs = tokenized.input_ids.to('cuda')

In [None]:
# encoded_msgs, encoder_logits = trans_encode(sample_inputs, attentions, VOCAB_SIZE)
# decoded_msgs, decoder_logits = trans_decode(encoded_msgs, VOCAB_SIZE)
# verify_msgs(decoded_msgs, sample_inputs, attentions).item()

True