In [None]:
!pip install transformers
!pip install datasets



# Arithmetic Coding (adapted from Project Nayuki) 

In [None]:
# 
# Reference arithmetic coding
# Copyright (c) Project Nayuki
# 
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
# 


# ---- Arithmetic coding core classes ----

class ArithmeticCoderBase:
	"""State and range-update behavior shared by the arithmetic encoder and decoder.

	The coder maintains an integer interval [low, high] that is num_state_bits wide.
	Subclasses implement shift() and underflow(), which update() invokes each time it
	renormalizes the interval.
	"""

	def __init__(self, numbits):
		"""Sets up the code range for a coder whose state is numbits bits wide.

		A larger state allows a larger frequency total and reduces rounding loss, at the
		cost of slower integer arithmetic. Python integers are unbounded, so there is no
		upper limit on numbits.

		Raises:
			ValueError: if numbits is less than 1.
		"""
		if numbits < 1:
			raise ValueError("State size out of range")

		# -- Configuration fields --
		# Width in bits of the 'low' and 'high' state variables (at least 1).
		self.num_state_bits = numbits
		# 2^num_state_bits: the largest possible value of (high + 1 - low).
		self.full_range = 1 << numbits
		# Top bit of a num_state_bits-wide value; never zero.
		self.half_range = self.full_range >> 1
		# Second-highest bit; this is zero when num_state_bits == 1.
		self.quarter_range = self.half_range >> 1
		# Smallest value (high + 1 - low) can take between symbols; at least 2.
		self.minimum_range = self.quarter_range + 2
		# Largest frequency-table total that update() accepts.
		self.maximum_total = self.minimum_range
		# num_state_bits one-bits, used to truncate shifted state.
		self.state_mask = self.full_range - 1

		# -- State fields --
		# Lower bound of the current range (conceptually followed by infinite 0 bits).
		self.low = 0
		# Upper bound of the current range (conceptually followed by infinite 1 bits).
		self.high = self.state_mask

	def update(self, freqs, symbol):
		"""Narrows [low, high] to the sub-interval of 'symbol', then renormalizes.

		Invariants holding before and after each call (full_range = 2^num_state_bits):
		- 0 <= low <= code <= high < full_range ('code' exists only in the decoder).
		- low < full_range/2 <= high, i.e. the bounds sit in different halves.
		- low < full_range/4 or high >= 3*full_range/4, i.e. the bounds are not both
		  in the middle two quarters.
		- minimum_range <= high - low + 1 <= full_range.

		Raises:
			AssertionError: if the stored state violates the invariants on entry.
			ValueError: if the symbol has zero frequency or the table total exceeds
				maximum_total.
		"""
		lo, hi = self.low, self.high
		mask = self.state_mask

		# Sanity-check the stored state before touching it.
		if lo >= hi or (lo & mask) != lo or (hi & mask) != hi:
			raise AssertionError("Low or high out of range")
		span = hi - lo + 1
		if span < self.minimum_range or span > self.full_range:
			raise AssertionError("Range out of range")

		# Query the table (total first, then the symbol's cumulative bounds) and validate.
		total = freqs.get_total()
		sym_lo = freqs.get_low(symbol)
		sym_hi = freqs.get_high(symbol)
		if sym_lo == sym_hi:
			raise ValueError("Symbol has zero frequency")
		if total > self.maximum_total:
			raise ValueError("Cannot code symbol because total is too large")

		# Scale the symbol's cumulative interval into the current code range.
		self.low = lo + sym_lo * span // total
		self.high = lo + sym_hi * span // total - 1

		top = self.half_range
		# Matching top bits are settled: hand them to shift() and slide both bounds left.
		while not ((self.low ^ self.high) & top):
			self.shift()
			self.low = (self.low << 1) & mask
			self.high = ((self.high << 1) & mask) | 1
		# From here low has top bit 0 and high has top bit 1.

		# low = 01... and high = 10...: drop the second-highest bit of both bounds.
		while self.low & ~self.high & self.quarter_range:
			self.underflow()
			self.low = (self.low << 1) ^ top
			self.high = ((self.high ^ top) << 1) | top | 1

	def shift(self):
		"""Hook invoked when the top bits of 'low' and 'high' are equal."""
		raise NotImplementedError()

	def underflow(self):
		"""Hook invoked when low = 01(...) and high = 10(...)."""
		raise NotImplementedError()



class ArithmeticEncoder(ArithmeticCoderBase):
	"""Encodes symbols and writes them to an arithmetic-coded bit stream."""

	def __init__(self, numbits, bitout):
		"""Creates an encoder with numbits bits of state that writes bits to bitout."""
		super().__init__(numbits)
		# Destination bit stream; must provide write(bit).
		self.output = bitout
		# Count of pending underflow bits; can grow without bound.
		self.num_underflow = 0

	def write(self, freqs, symbol):
		"""Encodes one symbol under the given frequency table.

		Updates the coder state and may emit bits. A table that is not already a
		CheckedFrequencyTable is wrapped in one first.
		"""
		checked = freqs if isinstance(freqs, CheckedFrequencyTable) else CheckedFrequencyTable(freqs)
		self.update(checked, symbol)

	def finish(self):
		"""Terminates the encoding by writing a final 1 bit.

		Must be called once at the end of every encoding run so the output can be
		decoded. It only writes to the underlying stream; it does not close it.
		"""
		self.output.write(1)

	def shift(self):
		"""Emits the settled top bit of 'low', followed by any saved underflow bits."""
		bit = self.low >> (self.num_state_bits - 1)
		self.output.write(bit)

		# Each saved underflow bit is the inverse of the bit just written.
		inverse = bit ^ 1
		remaining = self.num_underflow
		while remaining > 0:
			self.output.write(inverse)
			remaining -= 1
		self.num_underflow = 0

	def underflow(self):
		"""Remembers one more underflow bit to be emitted by the next shift()."""
		self.num_underflow += 1



class ArithmeticDecoder(ArithmeticCoderBase):
	"""Reads an arithmetic-coded bit stream and decodes symbols from it."""

	def __init__(self, numbits, bitin):
		"""Creates a decoder over bitin and preloads num_state_bits code bits."""
		super().__init__(numbits)
		# Source bit stream; read() must return 0, 1, or -1 at end of stream.
		self.input = bitin
		# Buffered raw code bits; always kept within [low, high].
		self.code = 0
		loaded = 0
		while loaded < self.num_state_bits:
			self.code = (self.code << 1) | self.read_code_bit()
			loaded += 1

	def read(self, freqs):
		"""Decodes and returns the next symbol under the given frequency table.

		Updates the coder state and may consume input bits. A table that is not
		already a CheckedFrequencyTable is wrapped in one first.

		Raises:
			ValueError: if the table total exceeds maximum_total.
			AssertionError: if the code value leaves [low, high] after the update.
		"""
		if not isinstance(freqs, CheckedFrequencyTable):
			freqs = CheckedFrequencyTable(freqs)

		# Map the code value from the coding-range scale to the frequency-table scale.
		total = freqs.get_total()
		if total > self.maximum_total:
			raise ValueError("Cannot decode symbol because total is too large")
		span = self.high - self.low + 1
		offset = self.code - self.low
		value = ((offset + 1) * total - 1) // span
		assert value * span // total <= offset
		assert 0 <= value < total

		# Binary search for the highest symbol whose cumulative low is <= value.
		lo_sym = 0
		hi_sym = freqs.get_symbol_limit()
		while hi_sym - lo_sym > 1:
			mid_sym = (lo_sym + hi_sym) >> 1
			if freqs.get_low(mid_sym) > value:
				hi_sym = mid_sym
			else:
				lo_sym = mid_sym
		assert lo_sym + 1 == hi_sym

		symbol = lo_sym
		assert freqs.get_low(symbol) * span // total <= offset < freqs.get_high(symbol) * span // total
		self.update(freqs, symbol)
		if self.code < self.low or self.code > self.high:
			raise AssertionError("Code out of range")
		return symbol

	def shift(self):
		"""Drops the top code bit and pulls the next input bit into the bottom."""
		shifted = (self.code << 1) & self.state_mask
		self.code = shifted | self.read_code_bit()

	def underflow(self):
		"""Removes the second-highest code bit and pulls in the next input bit."""
		top_bit = self.code & self.half_range
		lower_bits = (self.code << 1) & (self.state_mask >> 1)
		self.code = top_bit | lower_bits | self.read_code_bit()

	def read_code_bit(self):
		"""Returns the next input bit (0 or 1); end of stream reads as endless zeros."""
		bit = self.input.read()
		return 0 if bit == -1 else bit



# ---- Frequency table classes ----

class FrequencyTable:
	"""Abstract table of symbol frequencies.

	Holds data for symbols numbered 0 to get_symbol_limit()-1, each with a
	non-negative integer frequency. Mainly used to obtain cumulative frequencies.
	Implementations may or may not be mutable. Every method here raises
	NotImplementedError and must be overridden.
	"""

	def get_symbol_limit(self):
		"""Returns the number of symbols in the table (a positive number)."""
		raise NotImplementedError()

	def get(self, symbol):
		"""Returns the frequency of the given symbol (at least 0)."""
		raise NotImplementedError()

	def set(self, symbol, freq):
		"""Sets the frequency of the given symbol; freq must be at least 0."""
		raise NotImplementedError()

	def increment(self, symbol):
		"""Increments the frequency of the given symbol."""
		raise NotImplementedError()

	def get_total(self):
		"""Returns the sum of all frequencies, equal to get_high(get_symbol_limit() - 1)."""
		raise NotImplementedError()

	def get_low(self, symbol):
		"""Returns the summed frequency of all symbols strictly below 'symbol'."""
		raise NotImplementedError()

	def get_high(self, symbol):
		"""Returns the summed frequency of 'symbol' and all symbols below it."""
		raise NotImplementedError()



class FlatFrequencyTable(FrequencyTable):
	"""Immutable frequency table in which every symbol has frequency 1.

	Useful as a fallback model when no statistics are available.
	"""

	def __init__(self, numsyms):
		"""Creates a flat table with numsyms symbols.

		Raises:
			ValueError: if numsyms is less than 1.
		"""
		if numsyms < 1:
			raise ValueError("Number of symbols must be positive")
		# Number of symbols in the table; always at least 1.
		self.numsymbols = numsyms

	def __str__(self):
		"""Returns a debugging representation; the format is subject to change."""
		return "FlatFrequencyTable={}".format(self.numsymbols)

	def _check_symbol(self, symbol):
		"""Raises ValueError unless 0 <= symbol < numsymbols."""
		if not (0 <= symbol < self.numsymbols):
			raise ValueError("Symbol out of range")

	def get_symbol_limit(self):
		"""Returns the number of symbols, which is at least 1."""
		return self.numsymbols

	def get_total(self):
		"""Returns the frequency total, which equals the number of symbols."""
		return self.numsymbols

	def get(self, symbol):
		"""Returns the frequency of a valid symbol, which is always 1."""
		self._check_symbol(symbol)
		return 1

	def get_low(self, symbol):
		"""Returns the cumulative frequency below 'symbol', which equals 'symbol'."""
		self._check_symbol(symbol)
		return symbol

	def get_high(self, symbol):
		"""Returns the cumulative frequency through 'symbol', which equals 'symbol' + 1."""
		self._check_symbol(symbol)
		return symbol + 1

	def set(self, symbol, freq):
		"""Unsupported: this table is immutable."""
		raise NotImplementedError()

	def increment(self, symbol):
		"""Unsupported: this table is immutable."""
		raise NotImplementedError()



class SimpleFrequencyTable(FrequencyTable):
	"""Mutable table of symbol frequencies with a fixed number of symbols.

	Cumulative frequencies are rebuilt in linear time after any modification;
	faster schemes such as Fenwick trees exist but are not used here.
	"""

	def __init__(self, freqs):
		"""Builds the table from either a sequence or another frequency table.

		- SimpleFrequencyTable(sequence): uses the sequence of per-symbol frequencies.
		- SimpleFrequencyTable(freqtable): copies the frequencies of the given table.

		Raises:
			ValueError: if there are no symbols or any frequency is negative.
		"""
		if isinstance(freqs, FrequencyTable):
			count = freqs.get_symbol_limit()
			self.frequencies = list(map(freqs.get, range(count)))
		else:
			# Anything else is treated as a sequence and copied.
			self.frequencies = list(freqs)

		# 'frequencies' holds one non-negative entry per symbol and is never empty.
		if not self.frequencies:
			raise ValueError("At least 1 symbol needed")
		if any(f < 0 for f in self.frequencies):
			raise ValueError("Negative frequency")

		# Kept equal to the sum of 'frequencies' at all times.
		self.total = sum(self.frequencies)

		# cumulative[i] is the sum of frequencies[0:i]. Built lazily; None means stale.
		self.cumulative = None

	def get_symbol_limit(self):
		"""Returns the number of symbols, which is at least 1."""
		return len(self.frequencies)

	def get(self, symbol):
		"""Returns the frequency of the given symbol (at least 0)."""
		self._check_symbol(symbol)
		return self.frequencies[symbol]

	def set(self, symbol, freq):
		"""Sets the frequency of 'symbol' to 'freq' (which must be at least 0).

		The table is left unchanged if an exception is raised.
		"""
		self._check_symbol(symbol)
		if freq < 0:
			raise ValueError("Negative frequency")
		remainder = self.total - self.frequencies[symbol]
		assert remainder >= 0
		self.total = remainder + freq
		self.frequencies[symbol] = freq
		self.cumulative = None

	def increment(self, symbol):
		"""Adds 1 to the frequency of the given symbol."""
		self._check_symbol(symbol)
		self.total += 1
		self.frequencies[symbol] += 1
		self.cumulative = None

	def get_total(self):
		"""Returns the sum of all frequencies, equal to get_high(get_symbol_limit() - 1)."""
		return self.total

	def get_low(self, symbol):
		"""Returns the summed frequency of all symbols strictly below 'symbol'."""
		self._check_symbol(symbol)
		if self.cumulative is None:
			self._init_cumulative()
		return self.cumulative[symbol]

	def get_high(self, symbol):
		"""Returns the summed frequency of 'symbol' and all symbols below it."""
		self._check_symbol(symbol)
		if self.cumulative is None:
			self._init_cumulative()
		return self.cumulative[symbol + 1]

	def _init_cumulative(self):
		"""Rebuilds the cumulative-frequency list from 'frequencies'."""
		running = 0
		table = [running]
		for f in self.frequencies:
			running += f
			table.append(running)
		assert running == self.total
		self.cumulative = table

	def _check_symbol(self, symbol):
		"""Raises ValueError unless 0 <= symbol < len(frequencies)."""
		if not (0 <= symbol < len(self.frequencies)):
			raise ValueError("Symbol out of range")

	def __str__(self):
		"""Returns one 'symbol<TAB>frequency' line per symbol; for debugging only."""
		return "".join("{}\t{}\n".format(i, freq) for (i, freq) in enumerate(self.frequencies))



class CheckedFrequencyTable(FrequencyTable):
	"""Wrapper that validates the arguments and results of another frequency table.

	Useful for finding faults in a frequency table implementation: every call is
	forwarded to the wrapped table and an AssertionError is raised when the wrapped
	table accepts bad input or returns an inconsistent value.
	"""

	def __init__(self, freqtab):
		# The wrapped table that actually holds the data.
		self.freqtable = freqtab

	def __str__(self):
		return "CheckedFrequencyTable (" + str(self.freqtable) + ")"

	def _is_symbol_in_range(self, symbol):
		"""Returns whether 0 <= symbol < get_symbol_limit()."""
		return 0 <= symbol < self.get_symbol_limit()

	def get_symbol_limit(self):
		limit = self.freqtable.get_symbol_limit()
		if limit <= 0:
			raise AssertionError("Non-positive symbol limit")
		return limit

	def get(self, symbol):
		# Call through first: the wrapped table is expected to reject bad symbols itself.
		freq = self.freqtable.get(symbol)
		if not self._is_symbol_in_range(symbol):
			raise AssertionError("ValueError expected")
		if freq < 0:
			raise AssertionError("Negative symbol frequency")
		return freq

	def get_total(self):
		total = self.freqtable.get_total()
		if total < 0:
			raise AssertionError("Negative total frequency")
		return total

	def get_low(self, symbol):
		if not self._is_symbol_in_range(symbol):
			# The wrapped table should raise here; reaching the next line means it did not.
			self.freqtable.get_low(symbol)
			raise AssertionError("ValueError expected")
		low = self.freqtable.get_low(symbol)
		high = self.freqtable.get_high(symbol)
		if not (0 <= low <= high <= self.freqtable.get_total()):
			raise AssertionError("Symbol low cumulative frequency out of range")
		return low

	def get_high(self, symbol):
		if not self._is_symbol_in_range(symbol):
			# The wrapped table should raise here; reaching the next line means it did not.
			self.freqtable.get_high(symbol)
			raise AssertionError("ValueError expected")
		low = self.freqtable.get_low(symbol)
		high = self.freqtable.get_high(symbol)
		if not (0 <= low <= high <= self.freqtable.get_total()):
			raise AssertionError("Symbol high cumulative frequency out of range")
		return high

	def set(self, symbol, freq):
		self.freqtable.set(symbol, freq)
		# If the wrapped table accepted invalid input without raising, report it.
		if not self._is_symbol_in_range(symbol) or freq < 0:
			raise AssertionError("ValueError expected")

	def increment(self, symbol):
		self.freqtable.increment(symbol)
		# If the wrapped table accepted an invalid symbol without raising, report it.
		if not self._is_symbol_in_range(symbol):
			raise AssertionError("ValueError expected")



# ---- Bit-oriented I/O streams ----

class BitInputStream:
	"""Readable stream of bits backed by a byte stream.

	Because the bits come from whole bytes, their total count is always a multiple
	of 8. Bits within each byte are delivered most-significant first (big endian).
	"""

	def __init__(self, inp):
		"""Wraps the given byte input stream."""
		# Underlying byte stream.
		self.input = inp
		# Current byte in [0x00, 0xFF], or -1 once the end of stream has been reached.
		self.currentbyte = 0
		# Unread bits left in the current byte, always between 0 and 7 (inclusive).
		self.numbitsremaining = 0

	def read(self):
		"""Returns the next bit (0 or 1), or -1 at end of stream.

		The end of stream always falls on a byte boundary.
		"""
		if self.currentbyte == -1:
			return -1
		if self.numbitsremaining == 0:
			# Current byte is exhausted: fetch the next one.
			chunk = self.input.read(1)
			if len(chunk) == 0:
				self.currentbyte = -1
				return -1
			self.currentbyte = chunk[0]
			self.numbitsremaining = 8
		assert self.numbitsremaining > 0
		self.numbitsremaining -= 1
		return (self.currentbyte >> self.numbitsremaining) & 1

	def read_no_eof(self):
		"""Returns the next bit (0 or 1); raises EOFError at end of stream."""
		bit = self.read()
		if bit == -1:
			raise EOFError()
		return bit

	def close(self):
		"""Closes the underlying byte stream and marks this stream as ended."""
		self.input.close()
		self.currentbyte = -1
		self.numbitsremaining = 0



class BitOutputStream:
	"""Writable stream of bits backed by a byte stream.

	Bits are packed most-significant first (big endian). On close, the final
	partial byte is padded with 0 bits up to a multiple of 8.
	"""

	def __init__(self, out):
		"""Wraps the given byte output stream."""
		# Underlying byte stream.
		self.output = out
		# Bits accumulated for the byte in progress, always in [0x00, 0xFF].
		self.currentbyte = 0
		# Number of bits accumulated so far, always between 0 and 7 (inclusive).
		self.numbitsfilled = 0

	def write(self, b):
		"""Appends one bit, which must be 0 or 1.

		Raises:
			ValueError: if b is neither 0 nor 1.
		"""
		if b not in (0, 1):
			raise ValueError("Argument must be 0 or 1")
		self.currentbyte = (self.currentbyte << 1) | b
		self.numbitsfilled += 1
		if self.numbitsfilled != 8:
			return
		# A full byte has been collected: flush it and start over.
		self.output.write(bytes([self.currentbyte]))
		self.currentbyte = 0
		self.numbitsfilled = 0

	def close(self):
		"""Pads with 0 bits to the next byte boundary, then closes the underlying stream.

		Between 0 and 7 padding bits are written, the minimum needed.
		"""
		while self.numbitsfilled != 0:
			self.write(0)
		self.output.close()


Compress

In [None]:
# 
# Compression application using adaptive arithmetic coding
# 
# Usage: python adaptive-arithmetic-compress.py InputFile OutputFile
# Then use the corresponding adaptive-arithmetic-decompress.py application to recreate the original input file.
# Note that the application starts with a flat frequency table of 257 symbols (all set to a frequency of 1),
# and updates it after each byte encoded. The corresponding decompressor program also starts with a flat
# frequency table and updates it after each byte decoded. It is by design that the compressor and
# decompressor have synchronized states, so that the data can be decompressed properly.
# 
# Copyright (c) Project Nayuki
# 
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
# 
import contextlib, sys
# import arithmeticcoding


def main(args):
	"""Command-line entry point for compression: args is [InputFile, OutputFile].

	Exits with a usage message unless exactly two arguments are given; otherwise
	compresses InputFile into OutputFile.
	"""
	if len(args) != 2:
		sys.exit("Usage: python adaptive-arithmetic-compress.py InputFile OutputFile")
	src_path, dst_path = args

	# contextlib.closing makes sure the bit stream (and with it the output file) is closed.
	with open(src_path, "rb") as inp:
		with contextlib.closing(BitOutputStream(open(dst_path, "wb"))) as bitout:
			compress(inp, bitout)


def compress(inp, bitout):
	"""Arithmetic-codes every byte of 'inp' into 'bitout' with an adaptive model.

	The model starts as 257 symbols of frequency 1 (256 byte values plus an EOF
	symbol, 256) and is incremented after each encoded byte. The decompressor must
	mirror these updates exactly to stay in sync.
	"""
	model = SimpleFrequencyTable(FlatFrequencyTable(257))
	encoder = ArithmeticEncoder(32, bitout)

	# Encode one byte at a time until the input is exhausted.
	chunk = inp.read(1)
	while len(chunk) != 0:
		byte = chunk[0]
		encoder.write(model, byte)
		model.increment(byte)
		chunk = inp.read(1)

	encoder.write(model, 256)  # EOF
	encoder.finish()  # Flush remaining code bits


# Main launcher
# if __name__ == "__main__":
# 	main(sys.argv[1 : ])


Decompress

In [None]:
# 
# Decompression application using adaptive arithmetic coding
# 
# Usage: python adaptive-arithmetic-decompress.py InputFile OutputFile
# This decompresses files generated by the adaptive-arithmetic-compress.py application.
# 
# Copyright (c) Project Nayuki
# 
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
# 

import sys
# import arithmeticcoding


def main(args):
	"""Command-line entry point for decompression: args is [InputFile, OutputFile].

	Exits with a usage message unless exactly two arguments are given; otherwise
	decompresses InputFile into OutputFile.
	"""
	if len(args) != 2:
		sys.exit("Usage: python adaptive-arithmetic-decompress.py InputFile OutputFile")
	src_path, dst_path = args

	with open(src_path, "rb") as inp:
		with open(dst_path, "wb") as out:
			decompress(BitInputStream(inp), out)


def decompress(bitin, out):
	"""Decodes bytes from 'bitin' and writes them to 'out' until the EOF symbol.

	Uses the same adaptive model as compress(): 257 symbols starting at frequency 1,
	incremented after each decoded byte. Symbol 256 marks the end of the data.
	"""
	model = SimpleFrequencyTable(FlatFrequencyTable(257))
	decoder = ArithmeticDecoder(32, bitin)

	symbol = decoder.read(model)
	while symbol != 256:  # 256 is the EOF symbol
		out.write(bytes((symbol,)))
		model.increment(symbol)
		symbol = decoder.read(model)


# Main launcher
# if __name__ == "__main__":
# 	main(sys.argv[1 : ])


# Get CNN/DailyMail data ready

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gzip

# Prefer the GPU, but fall back to the CPU so this cell also runs on runtimes without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load pretrained GPT-2 and its tokenizer.
model = GPT2LMHeadModel.from_pretrained('gpt2').to(DEVICE)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 has no dedicated padding token, so the end-of-sequence token is reused for padding.
tokenizer.pad_token = tokenizer.eos_token
# Switch to inference mode; assigning to `_` suppresses the model repr in the cell output.
_ = model.eval()

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/523M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

In [None]:
# from datasets import load_dataset, DatasetDict
# news_dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0")
# # Remove unneeded columns, just keep "article"
# news_dataset = news_dataset.remove_columns("id")
# news_dataset = news_dataset.remove_columns("highlights")

Downloading:   0%|          | 0.00/9.27k [00:00<?, ?B/s]

Downloading and preparing dataset cnn_dailymail/3.0.0 to /root/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f...


  0%|          | 0/5 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/572k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/661k [00:00<?, ?B/s]

  0%|          | 0/5 [00:00<?, ?it/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset cnn_dailymail downloaded and prepared to /root/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
from datasets import Dataset
import json
import gdown

# Download the 1000-article JSON file from Google Drive and load it as a Dataset.
url = 'https://drive.google.com/uc?id=1sAgDtEj-UjJECfTF6xfiWFk7lrTX7yoV'
filename = "articles_1000.json"
gdown.download(url, filename, quiet=False)
with open(filename, 'r') as f:
    data = json.load(f)
dataset1000 = Dataset.from_dict(data)
# The 300- and 500-article subsets are consumed by the tokenization cells below,
# so they must be defined here for the notebook to run on a fresh kernel.
dataset300 = dataset1000.select(range(300))
dataset500 = dataset1000.select(range(500))

Downloading...
From: https://drive.google.com/uc?id=1sAgDtEj-UjJECfTF6xfiWFk7lrTX7yoV
To: /content/articles_1000.json
100%|██████████| 3.77M/3.77M [00:00<00:00, 100MB/s]


In [None]:
# Tokenize the 300-article subset in batches of 16, padding/truncating every article to 512 tokens,
# then expose the token ids and attention masks as torch tensors.
d300 = dataset300.map(
    lambda batch: tokenizer(
        batch['article'],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    ),
    batch_size=16,
    batched=True,
)
d300.set_format(columns=['input_ids', 'attention_mask'], type='torch')

  0%|          | 0/19 [00:00<?, ?ba/s]

In [None]:
# Tokenize the 500-article subset in batches of 16, padding/truncating every article to 512 tokens,
# then expose the token ids and attention masks as torch tensors.
d500 = dataset500.map(
    lambda batch: tokenizer(
        batch['article'],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    ),
    batch_size=16,
    batched=True,
)
d500.set_format(columns=['input_ids', 'attention_mask'], type='torch')

  0%|          | 0/32 [00:00<?, ?ba/s]

In [None]:
# Tokenize all 1000 articles in batches of 16, padding/truncating every article to 512 tokens,
# then expose the token ids and attention masks as torch tensors.
d1000 = dataset1000.map(
    lambda batch: tokenizer(
        batch['article'],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    ),
    batch_size=16,
    batched=True,
)
d1000.set_format(columns=['input_ids', 'attention_mask'], type='torch')

  0%|          | 0/63 [00:00<?, ?ba/s]

In [None]:
# Display the tokenized dataset summary (features and row count).
d1000

Dataset({
    features: ['article', 'input_ids', 'attention_mask'],
    num_rows: 1000
})

In [None]:
# dev_set = news_dataset['validation'].select(range(1000))
# tokenized_dev_set = dev_set.map(
#     lambda example: tokenizer(example['article'], return_tensors="np", padding="max_length", truncation=True, max_length=512),
#     batched=True,
#     batch_size=16
# )
# tokenized_dev_set.set_format(type='torch', columns=['input_ids', 'attention_mask'])

# # Filter to only articles whose tokenized length is less than 1024
# # because we want to be able to compare the gzipped full article
# # to the neurally compressed full article
# # Compressing truncated articles would mean that the gzipped version
# # would compress the untruncated article but the neural compressor
# # would only compress the truncated version :(
# tokenized_dev_set = tokenized_dev_set.filter(
#     lambda example: not torch.all(example['attention_mask'] == 1)
# )


  0%|          | 0/63 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import sys
# Add 'drive/Project' (relative to the current working directory) to the module search path.
# NOTE(review): Drive is mounted at '/content/drive' above — confirm this relative path
# resolves to the intended project folder.
sys.path.append('drive/Project')

In [None]:
# import numpy as np
# np.savetxt('input.txt', temp.numpy(), fmt="%s")
# torch.save(temp, 'tensor.pt')

In [None]:
# with open('input.txt', 'r') as out:
    # arith_encoding_size = len(bytes(out.read(), 'utf-8'))
# print(arith_encoding_size)

2663


In [None]:
from torch.utils.data import DataLoader
# Number of articles per batch; also used as the step of the progress counter in the stats loops.
BATCH_SIZE = 8

# NOTE(review): the 300/500 loaders below are commented out, yet the stats_300 and stats_500
# cells further down still iterate over them.
# eval_dataloader_300 = DataLoader(d300, batch_size=BATCH_SIZE)
# article_dataloader_300 = DataLoader(d300['article'], batch_size=BATCH_SIZE)

# eval_dataloader_500 = DataLoader(d500, batch_size=BATCH_SIZE)
# article_dataloader_500 = DataLoader(d500['article'], batch_size=BATCH_SIZE)

# Two parallel loaders with the same batch size: one over the tokenized dataset and one over
# the raw article column, so the stats loop can zip them batch by batch.
eval_dataloader_1000 = DataLoader(d1000, batch_size=BATCH_SIZE)
article_dataloader_1000 = DataLoader(d1000['article'], batch_size=BATCH_SIZE)

In [None]:
# Evaluation functions
import numpy as np 
import gzip
import os
def get_compressed_size(encoded_msgs, attentions, messages):
  """Measures, per message in a batch, the raw, arithmetic-coded and gzip sizes in bytes.

  For each row, the token ids are cut at the first 0 in the attention mask, rendered as
  text (a Python list literal with spaces removed), written to 'input.txt', and
  compressed into 'output.txt' with the adaptive arithmetic coder `compress`. The
  original message text is separately compressed with gzip (level 9).

  Args:
    encoded_msgs: 2-D tensor of token ids; its first dimension is the batch size.
    attentions: per-row attention masks, indexable by batch position; a 0 marks padding.
    messages: original message strings, indexable by batch position.

  Returns:
    A list with one tuple per message:
    (UTF-8 byte length of the message, size of 'output.txt' in bytes, gzip size in bytes).

  Side effects:
    Overwrites 'input.txt' and 'output.txt' in the current working directory.
  """
  batch_size = encoded_msgs.size()[0]
  sizes = []
  for i in range(batch_size):
    # Convert the mask once; the first 0 in it marks where padding starts.
    attention = attentions[i].tolist()
    try:
      end = attention.index(0)
    except ValueError:
      end = -1  # No 0 in the mask: the row contains no padding.

    if end == -1:  # Entire encoding represents a message. No padding
      encoding = encoded_msgs[i].tolist()
    else:
      # NOTE(review): the slice keeps the token at index `end`, i.e. the first padded
      # position (pad token == EOS token for this tokenizer) — confirm that including
      # it as an end-of-text marker is intended.
      encoding = encoded_msgs[i][:end+1].tolist()

    # Write the token ids as compact text, e.g. "[7,18474,8]".
    encoding = str(encoding).replace(" ", "")
    with open('input.txt', "w") as f:
      f.write(encoding)

    # Perform file compression
    with open('input.txt', "rb") as inp, \
        contextlib.closing(BitOutputStream(open("output.txt", "wb"))) as bitout:
      compress(inp, bitout)

    # Compare the compressed token stream against the original text and its gzip size.
    arith_encoding_size = os.path.getsize('output.txt')
    orig_msg_bytes = bytes(messages[i], 'utf-8')
    gzip_encoding_size = len(gzip.compress(orig_msg_bytes, compresslevel=9))
    sizes.append((len(orig_msg_bytes), arith_encoding_size, gzip_encoding_size))
  return sizes

In [None]:
# Build the loaders in this cell so it does not depend on loader definitions elsewhere.
eval_dataloader_300 = DataLoader(d300, batch_size=BATCH_SIZE)
article_dataloader_300 = DataLoader(d300['article'], batch_size=BATCH_SIZE)

# Collect per-batch size statistics for the 300-article subset.
stats_300 = []
i = 0  # index of the first article in the current batch
for batch, articles in zip(eval_dataloader_300, article_dataloader_300):
  compressed_sizes = get_compressed_size(batch['input_ids'], batch['attention_mask'], articles)
  print(i, compressed_sizes)
  stats_300.append(compressed_sizes)
  i += BATCH_SIZE

encoding:  [7,18474,8,14731,6705,11,1266,1900,329,465,33578,286,275,14739,16570,10018,1073,350,13,1623,2213,1531,319,3195,338,366,464,360,31469,286,367,8101,446,553,3724,3321,706,257,4506,8526,13,679,373,9193,13,6705,3724,287,10496,501,287,42441,652,11,2258,5913,11,286,19481,422,35647,11,531,6542,406,1381,26615,11,257,15076,1545,290,8502,16008,13,4900,339,1549,587,257,8179,8674,329,4647,287,13766,290,287,8502,11,6705,1422,470,1716,5863,1566,13521,11,618,366,464,360,31469,286,367,8101,446,338,1,11676,79,505,41700,2540,307,3723,656,5242,286,1605,5682,2048,790,3217,1755,13,1114,3598,7028,11,6705,338,10018,1073,350,13,1623,2213,1531,26172,262,31093,71,500,12,20270,11083,6510,736,290,6071,1973,262,736,9725,286,46718,367,8101,446,3418,11,7859,11,3584,465,366,8940,14748,1,3221,4444,351,683,21899,465,13969,1097,13,4900,10018,1073,373,3105,12,86,2175,290,10622,11,6705,2921,683,257,1200,2339,17131,326,1392,22051,290,925,683,886,6648,13,2399,2095,2627,1900,329,465,18778,366,365,86,12,365,86,12,36

In [None]:
# Build the loaders in this cell so it does not depend on loader definitions elsewhere.
eval_dataloader_500 = DataLoader(d500, batch_size=BATCH_SIZE)
article_dataloader_500 = DataLoader(d500['article'], batch_size=BATCH_SIZE)

# Collect per-batch size statistics for the 500-article subset.
stats_500 = []
i = 0  # index of the first article in the current batch
for batch, articles in zip(eval_dataloader_500, article_dataloader_500):
  compressed_sizes = get_compressed_size(batch['input_ids'], batch['attention_mask'], articles)
  print(i, compressed_sizes)
  stats_500.append(compressed_sizes)
  i += BATCH_SIZE

0 [(1091, 1505), (1108, 1537), (1094, 2680), (1106, 1331), (1096, 2247), (1068, 1527), (1091, 1684), (1103, 2576)]
8 [(486, 555), (1106, 3211), (1081, 1800), (1122, 2196), (739, 849), (1102, 1390), (1091, 4637), (1091, 4571)]
16 [(1063, 2228), (1095, 4121), (921, 1055), (1086, 3936), (861, 958), (1097, 2386), (723, 878), (1091, 3029)]
24 [(1105, 1609), (704, 789), (1085, 1453), (1057, 1671), (1083, 2205), (1075, 2757), (1098, 1300), (1099, 2601)]
32 [(1087, 2309), (1077, 3607), (399, 456), (1100, 2294), (1098, 3628), (1099, 1997), (1120, 2686), (814, 878)]
40 [(1080, 1825), (1106, 2942), (1093, 1375), (1033, 1187), (1095, 2136), (1098, 3678), (1091, 3826), (1110, 1560)]
48 [(1105, 1385), (947, 984), (841, 978), (1091, 1999), (1098, 2552), (703, 761), (1094, 2797), (1099, 2393)]
56 [(1085, 2742), (897, 1072), (1086, 1287), (1108, 2083), (1102, 2892), (1073, 2322), (887, 1013), (976, 1116)]
64 [(1083, 2576), (1083, 2271), (1092, 1942), (1079, 2438), (1106, 2110), (600, 702), (745, 864), 

In [None]:
# Collect per-batch size statistics for the full 1000-article set.
# Each element appended to stats_1000 is the list returned by get_compressed_size for one batch.
stats_1000 = []
i = 0  # index of the first article in the current batch; only used by the commented-out print
for batch, articles in zip(eval_dataloader_1000, article_dataloader_1000):
  compressed_sizes = get_compressed_size(batch['input_ids'], batch['attention_mask'], articles)
  # print(i, compressed_sizes)
  stats_1000.append(compressed_sizes)
  i += BATCH_SIZE

In [None]:
# Show the statistics of the first batch: one (original bytes, arithmetic-coded bytes, gzip bytes) tuple per article.
print(stats_1000[0])

[(2772, 1091, 1505), (3157, 1108, 1537), (5709, 1094, 2680), (2581, 1106, 1331), (4770, 1096, 2247), (3017, 1068, 1527), (3453, 1091, 1684), (5449, 1103, 2576)]


In [None]:
import json
# with open('stats300.json', 'w') as stats_file:
#     json.dump(stats_300, stats_file)
# with open('stats500.json', 'w') as stats_file:
#     json.dump(stats_500, stats_file)
# Persist the collected stats_1000 list as JSON in the working directory.
with open('stats1000.json', mode='w') as stats_file:
    stats_file.write(json.dumps(stats_1000))

Old

In [None]:
import gzip

In [None]:
%%timeit
# Timed body: gzip every article's UTF-8 bytes at the maximum compression
# level; the compressed result is discarded, only the elapsed time matters.
for article in tokenized_dev_set['article']:
  gzip.compress(article.encode('utf-8'), compresslevel=9)


Dataset({
    features: ['article', 'input_ids', 'attention_mask'],
    num_rows: 301
})

In [None]:
from torch.utils.data import DataLoader

# Wrap the tokenized dev set in a loader that yields two examples per batch.
eval_dataloader = DataLoader(
    dataset=tokenized_dev_set,
    batch_size=2,
)


In [None]:
# Encode only the first batch from eval_dataloader: the loop exits right
# after one call to trans_encode, leaving `batch` bound for later cells.
for batch in eval_dataloader:
  cuda_input_ids = batch['input_ids'].to('cuda')
  cuda_attention_mask = batch['attention_mask'].to('cuda')
  encoded_msgs, encoder_logits = trans_encode(cuda_input_ids, cuda_attention_mask, VOCAB_SIZE)
  break


In [None]:
# Decode the messages produced by the encode cell.
# NOTE(review): the saved output of this cell is a RuntimeError, so the stored
# results further down may come from a different run - confirm.
decode_result = trans_decode(encoded_msgs, VOCAB_SIZE)
decoded_msgs, logits_arr = decode_result


RuntimeError: ignored

In [None]:
# Check the decoded messages against the batch that was encoded; the call is
# the cell's last expression so its result is displayed.
expected_input_ids = batch['input_ids'].to('cuda')
expected_attention_mask = batch['attention_mask'].to('cuda')
verify_msgs(decoded_msgs, expected_input_ids, expected_attention_mask)

tensor(True, device='cuda:0')

In [None]:
# messages = [" But if you are preparing data and doing cat in each iteration, it gets really slow when the tensor you are generating gets very large. My solution was to cat into", 
#             "msg 2 baby", 
#             "The Boat Race 2021 comprised two side-by-side rowing races that took place on 4 April. The Boat Race is contested annually between crews from the universities of Oxford and Cambridge. Traditionally held on the Championship Course in London, the 2021 race instead took place on the River Great Ouse near Ely (course map pictured). This was the 75th women's race and the 166th men's race;"]

# tokenized = tokenizer(messages, return_tensors="pt", padding="longest", truncation=True, max_length=1024)
# attentions = tokenized.attention_mask.to('cuda')
# # sample_inputs = tokenized.input_ids.to('cuda')

In [None]:
# Round trip on the sample messages: encode, decode, then compare the decoded
# output with the original inputs and display the check as a plain Python value.
# NOTE(review): sample_inputs / attentions are not assigned in any live cell
# nearby - confirm where they come from.
encode_result = trans_encode(sample_inputs, attentions, VOCAB_SIZE)
encoded_msgs, encoder_logits = encode_result
decode_result = trans_decode(encoded_msgs, VOCAB_SIZE)
decoded_msgs, decoder_logits = decode_result
roundtrip_check = verify_msgs(decoded_msgs, sample_inputs, attentions)
roundtrip_check.item()

True