# LLM-архиватор. Эксперименты со steering vector

In [None]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.notebook import tqdm
import time
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import gc

In [None]:
def clear_memory(var_names=None):
    """Delete selected module-level variables and release cached memory.

    Args:
        var_names: Names of global variables to remove. When ``None`` a
            built-in list of the notebook's heavyweight variables is used.
            Names that are not currently defined are skipped.
    """
    default_names = [
        'model', 'tokenizer', 'enc', 'dec', 'logits', 'probs', 'inp', 'ctx',
        'ids', 'recovered_ids', 'all_bits', 'all_original_ids',
    ]
    names = default_names if var_names is None else var_names

    namespace = globals()
    for name in names:
        # pop() with a default does nothing for names that do not exist.
        namespace.pop(name, None)

    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

## Общие и вспомогательные функции кодирования и декодирования

In [None]:
def probs_to_cdf_int(prob, total=1 << 20):
    """Quantise a probability vector into an integer CDF for arithmetic coding.

    Every symbol is guaranteed a non-empty slot (``cdf[k + 1] > cdf[k]``).
    Without that guarantee a symbol whose probability rounds to zero gets
    an empty interval and can be neither encoded nor decoded.

    Args:
        prob: 1-D array-like of non-negative probabilities.
        total: Integer the CDF is scaled to; the last entry equals ``total``.

    Returns:
        ``np.int64`` array of length ``len(prob) + 1`` that starts at 0, ends
        at ``total`` and is strictly increasing.

    Raises:
        ValueError: If ``prob`` is empty or has more symbols than ``total``.
    """
    cdf = np.cumsum(prob, dtype=np.float64)
    n_symbols = cdf.size
    if n_symbols == 0:
        raise ValueError("prob must contain at least one symbol")
    if n_symbols > total:
        raise ValueError("total is too small to give every symbol a slot")

    # One count is reserved per symbol; only the remainder is split
    # proportionally to the probabilities.
    spare = total - n_symbols
    scaled = np.floor(cdf * spare).astype(np.int64)
    scaled = np.maximum.accumulate(scaled)
    # A float cumsum may end slightly above 1.0; clipping keeps every entry
    # below the final value, so the last symbol keeps a non-empty slot.
    scaled = np.clip(scaled, 0, spare)

    reserved = np.arange(1, n_symbols + 1, dtype=np.int64)
    cdf_int = np.concatenate(([0], scaled + reserved))
    cdf_int[-1] = total
    return cdf_int

In [None]:
class ArithmeticEncoder:
    """Integer arithmetic encoder with E1/E2/E3 renormalisation.

    Bits are collected in ``self.out`` as a list of 0/1 integers.

    Args:
        precision: Width in bits of the ``low``/``high`` registers.
    """

    def __init__(self, precision=32):
        self.precision = precision
        self.half = 1 << (precision - 1)
        self.quarter = self.half >> 1
        self.mask = (1 << precision) - 1
        self.low = 0
        self.high = self.mask
        # Number of deferred (underflow) bits whose value is not yet known.
        self.pending = 0
        self.out = []

    def update(self, cdf_low, cdf_high, total):
        """Encode one symbol occupying ``[cdf_low, cdf_high)`` out of ``total``.

        Raises:
            ValueError: If the interval is empty or lies outside
                ``[0, total]``; encoding it would corrupt the stream.
        """
        # Callers pass numpy integers; plain ints keep the register
        # arithmetic arbitrary-precision instead of wrapping at 64 bits.
        cdf_low, cdf_high, total = int(cdf_low), int(cdf_high), int(total)
        if not 0 <= cdf_low < cdf_high <= total:
            raise ValueError(
                f"invalid symbol interval [{cdf_low}, {cdf_high}) for total={total}"
            )

        # subdivide interval
        rng = self.high - self.low + 1
        self.high = self.low + (rng * cdf_high) // total - 1
        self.low = self.low + (rng * cdf_low) // total

        # renormalize
        while True:
            if self.high < self.half:
                # E1: both bounds in the lower half, the leading bit is 0
                self._emit(0)
            elif self.low >= self.half:
                # E2: both bounds in the upper half, the leading bit is 1
                self._emit(1)
                self.low -= self.half
                self.high -= self.half
            elif self.low >= self.quarter and self.high < 3 * self.quarter:
                # E3: interval straddles the midpoint; defer the bit
                self.pending += 1
                self.low -= self.quarter
                self.high -= self.quarter
            else:
                break

            # shift out the resolved bit; high is refilled with a 1
            self.low = (self.low << 1) & self.mask
            self.high = ((self.high << 1) & self.mask) | 1

    def _emit(self, bit):
        """Append ``bit`` followed by all pending bits, which are its inverse."""
        self.out.append(bit)
        for _ in range(self.pending):
            self.out.append(1 - bit)
        self.pending = 0

    def finish(self):
        """Terminate the stream; call once after the last ``update``."""
        self.pending += 1
        if self.low < self.quarter:
            self._emit(0)
        else:
            self._emit(1)
        # NOTE(review): the bits above look sufficient to pin the final
        # interval, which would make this flush of ``low`` redundant.
        # It is kept so the emitted stream stays identical.
        for _ in range(self.precision):
            self.out.append((self.low >> (self.precision - 1)) & 1)
            self.low = (self.low << 1) & self.mask

In [None]:
class ArithmeticDecoder:
    """Integer arithmetic decoder mirroring ``ArithmeticEncoder``.

    Args:
        bits: Sequence of 0/1 values produced by the encoder.
        precision: Register width in bits; must match the encoder's.
    """

    def __init__(self, bits, precision=32):
        self.bits = bits
        self.precision = precision
        self.half = 1 << (precision - 1)
        self.quarter = self.half >> 1
        self.mask = (1 << precision) - 1

        self.low = 0
        self.high = self.mask
        self.value = 0
        self.idx = 0
        # Prime the value register with the first `precision` bits.
        for _ in range(precision):
            self.value = ((self.value << 1) & self.mask) | self._read()

    def _read(self):
        """Return the next input bit, or 0 once the input is exhausted."""
        if self.idx < len(self.bits):
            b = self.bits[self.idx]
            self.idx += 1
            return b
        return 0

    def decode(self, cdf_int, total):
        """Decode one symbol.

        Args:
            cdf_int: Integer CDF (length ``n_symbols + 1``) that was used to
                encode this symbol.
            total: Scale of ``cdf_int`` (its last entry).

        Returns:
            The decoded symbol index as a plain ``int``.
        """
        total = int(total)

        # find symbol
        rng = self.high - self.low + 1
        scaled = ((self.value - self.low + 1) * total - 1) // rng
        symbol = int(np.searchsorted(cdf_int, scaled, side='right')) - 1

        # narrow interval; int() keeps the registers as Python ints so the
        # products below cannot wrap around as numpy int64 would
        c_lo, c_hi = int(cdf_int[symbol]), int(cdf_int[symbol + 1])
        self.high = self.low + (rng * c_hi) // total - 1
        self.low = self.low + (rng * c_lo) // total

        # renormalize, same three cases as the encoder
        while True:
            if self.high < self.half:
                pass
            elif self.low >= self.half:
                self.value -= self.half
                self.low -= self.half
                self.high -= self.half
            elif self.low >= self.quarter and self.high < 3 * self.quarter:
                self.value -= self.quarter
                self.low -= self.quarter
                self.high -= self.quarter
            else:
                break

            self.low = (self.low << 1) & self.mask
            self.high = ((self.high << 1) & self.mask) | 1
            self.value = ((self.value << 1) & self.mask) | self._read()

        return symbol

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def load_model(model_name: str):
    """Load a causal LM and its tokenizer for inference.

    The model is switched to eval mode and moved to the module-level
    ``device``.

    Returns:
        ``(tokenizer, model)``.
    """
    lm = AutoModelForCausalLM.from_pretrained(model_name)
    lm.eval()
    lm.to(device)
    return AutoTokenizer.from_pretrained(model_name), lm

In [None]:
def get_data(size: int, tokenizer, path: str = 'enwik8') -> bytes:
    """Read the first ``size`` bytes of a file.

    Args:
        size: Maximum number of bytes to read.
        tokenizer: Unused; kept so existing call sites keep working.
        path: File to read from.

    Returns:
        The raw bytes. The encoders in this file decode them as latin-1
        chunk by chunk.
    """
    with open(path, 'rb') as f:
        return f.read(size)

In [None]:
def text_to_ids(text: str, tokenizer) -> list[int]:
    """Tokenize ``text`` without adding special tokens."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return token_ids

## Кодирование и декодирование со steering vector

In [None]:
def encode_chunks_with_steering(
    data: bytes,
    steering_ids: list[int],
    tokenizer,
    model,
    chunk_size: int = 2000,
    total: int = 1 << 30
):
    """Arithmetic-code ``data`` chunk by chunk with a fixed token prefix.

    Each chunk is decoded as latin-1, tokenized, and every token after the
    first is coded with the model's next-token distribution given
    ``steering_ids`` followed by the preceding tokens of the same chunk.

    Args:
        data: Raw bytes to compress.
        steering_ids: Token ids prepended to every model context.
        tokenizer: Tokenizer exposing ``encode``.
        model: Causal LM returning ``.logits``.
        chunk_size: Number of bytes per independently coded chunk.
        total: Scale of the integer CDF passed to the coder.

    Returns:
        One dict per chunk with keys ``'ids'`` (all token ids of the chunk)
        and ``'bits'`` (encoder output bits).
    """
    t0 = time.perf_counter()
    # History positions left once the steering prefix is prepended.
    window = model.config.max_position_embeddings - len(steering_ids)
    encoded_chunks = []

    for offset in tqdm(range(0, len(data), chunk_size), desc='Encoding chunks'):
        piece = data[offset:offset + chunk_size]
        ids = tokenizer.encode(piece.decode('latin-1'), add_special_tokens=False)

        coder = ArithmeticEncoder()
        for pos in tqdm(range(1, len(ids)), desc='  Tokens in chunk', leave=False):
            first = max(0, pos - window)
            context = torch.tensor([steering_ids + ids[first:pos]], device=device)

            with torch.no_grad():
                last_logits = model(context).logits[0, -1]
                dist = torch.softmax(last_logits, dim=-1).cpu().numpy()

            cdf = probs_to_cdf_int(dist, total)
            target = ids[pos]
            coder.update(cdf[target], cdf[target + 1], total)

        coder.finish()
        encoded_chunks.append({'ids': ids, 'bits': coder.out})

    elapsed = time.perf_counter() - t0
    print(f"Total encoding time: {elapsed:.2f} seconds")
    return encoded_chunks

In [None]:
def decode_chunks_with_steering(
    chunk_data,
    steering_ids: list[int],
    model,
    total: int = 1 << 30
):
    """Decode chunks produced by ``encode_chunks_with_steering``.

    The model context is built from the tokens decoded so far, never from
    the stored ground-truth ids, so a matching result shows the bitstream
    is decodable. From ``entry['ids']`` only the first token (which the
    encoder does not code) and the token count are used.

    Args:
        chunk_data: List of dicts with keys ``'ids'`` and ``'bits'``.
        steering_ids: The same prefix token ids that were used for encoding.
        model: The same causal LM that was used for encoding.
        total: CDF scale; must match the encoder's.

    Returns:
        Flat list of recovered token ids over all chunks.
    """
    start_time = time.perf_counter()
    all_recovered_ids = []

    # History positions left once the steering prefix is prepended.
    window = model.config.max_position_embeddings - len(steering_ids)

    for entry in tqdm(chunk_data, desc='Decoding chunks'):
        ids = entry['ids']
        bits = entry['bits']
        dec = ArithmeticDecoder(bits)

        rec = [ids[0]]
        for i in tqdm(range(1, len(ids)), desc='  Tokens in chunk', leave=False):
            # Condition on decoded tokens only, as a real decoder must.
            history = rec[max(0, i - window):i]
            ctx_ids = steering_ids + history
            inp = torch.tensor([ctx_ids], device=device)

            with torch.no_grad():
                logits = model(inp).logits[0, -1]
                probs = torch.softmax(logits, dim=-1).cpu().numpy()

            cdf = probs_to_cdf_int(probs, total)
            # int() keeps `rec` a list of plain Python ints.
            rec.append(int(dec.decode(cdf, total)))

        all_recovered_ids.extend(rec)

    decoding_time = time.perf_counter() - start_time
    print(f"Total decoding time: {decoding_time:.2f} seconds")
    return all_recovered_ids

In [None]:
def compare_sequences(orig, dec):
    """Print whether two sequences are identical.

    Reports the first position at which they differ; if the common prefix
    matches, reports a length mismatch; otherwise reports success.
    """
    first_diff = next(
        ((pos, a, b) for pos, (a, b) in enumerate(zip(orig, dec)) if a != b),
        None,
    )
    if first_diff is not None:
        pos, a, b = first_diff
        print(f"❌  Расхождение на позиции {pos}: orig={a}  decoded={b}")
    elif len(orig) != len(dec):
        print(f"❌  Длины списков отличаются: orig={len(orig)}  decoded={len(dec)}")
    else:
        print("✅  Совпадают полностью!")

In [None]:
def compression_stats(data: bytes, enc_out: list[int]):
    """Print original size, compressed size and their ratio, all in bits.

    Args:
        data: The uncompressed input bytes.
        enc_out: Encoder output with one list element per bit, so its
            length is the compressed size in bits.
    """
    original_bits = len(data) * 8
    compressed_bits = len(enc_out)
    # The ratio is undefined for empty input; report NaN instead of raising.
    ratio = compressed_bits / original_bits if original_bits else float('nan')

    print(f"Исходный размер:   {original_bits} бит")
    print(f"Размер после сжатия: {compressed_bits} бит")
    print(f"Коэффициент сжатия: {ratio:.4f}")

In [None]:
def decode_text_from_ids(ids, tokenizer) -> str:
    """Turn token ids back into text.

    Tokenization-space clean-up is disabled and special tokens are kept.
    """
    decode_options = {
        'clean_up_tokenization_spaces': False,
        'skip_special_tokens': False,
    }
    return tokenizer.decode(ids, **decode_options)

## Фиксированная строка как steering vector

In [None]:
# Load the Pythia-70M tokenizer and model used by the cells below.
tokenizer, model = load_model('EleutherAI/pythia-70m')

In [None]:
# Read the first 50,000 bytes of the corpus (get_data's default path is 'enwik8').
data = get_data(50000, tokenizer)

In [None]:
# Token ids of a fixed text that is prepended to every model context.
steering_ids = text_to_ids("This is Wikipedia html", tokenizer)

In [None]:
# Compress the data chunk by chunk using the fixed-string steering prefix.
chunk_data = encode_chunks_with_steering(
    data, steering_ids, tokenizer, model)

Encoding chunks:   0%|          | 0/25 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/668 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/732 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/572 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/589 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/577 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/546 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/505 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/568 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/500 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/511 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/487 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/456 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/465 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/514 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/426 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/537 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/600 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/473 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/495 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/442 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/452 [00:00<?, ?it/s]

Total encoding time: 162.76 seconds


In [None]:
# Decode the per-chunk bitstreams back into token ids.
recovered_ids = decode_chunks_with_steering(
    chunk_data, steering_ids, model)

Decoding chunks:   0%|          | 0/25 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/668 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/732 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/572 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/589 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/577 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/546 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/505 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/568 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/500 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/511 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/487 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/456 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/465 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/514 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/426 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/537 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/600 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/473 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/495 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/442 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/452 [00:00<?, ?it/s]

Total decoding time: 146.35 seconds


In [None]:
# Flatten the per-chunk token ids and bits into single lists.
# A comprehension is linear; sum(lists, []) re-copies the accumulator at
# every step and is quadratic in the total length.
orig_ids = [tok for entry in chunk_data for tok in entry['ids']]
encoded_output = [bit for entry in chunk_data for bit in entry['bits']]

In [None]:
# Check that decoding reproduced the original token ids.
compare_sequences(orig_ids, recovered_ids)

✅  Совпадают полностью!


In [None]:
# Report sizes in bits; encoded_output holds one list element per bit.
compression_stats(data, encoded_output)

Исходный размер:   400000 бит
Размер после сжатия: 66406 бит
Коэффициент сжатия: 0.1660


In [None]:
# Detokenize the recovered ids back into text.
recovered_text = decode_text_from_ids(recovered_ids, tokenizer)

In [None]:
# Display the recovered text as the cell output.
recovered_text

'<mediawiki xmlns="http://www.mediawiki.org/xml/export-0.3/" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.mediawiki.org/xml/export-0.3/ http://www.mediawiki.org/xml/export-0.3.xsd" version="0.3" xml:lang="en">\n  <siteinfo>\n    <sitename>Wikipedia</sitename>\n    <base>http://en.wikipedia.org/wiki/Main_Page</base>\n    <generator>MediaWiki 1.6alpha</generator>\n    <case>first-letter</case>\n      <namespaces>\n      <namespace key="-2">Media</namespace>\n      <namespace key="-1">Special</namespace>\n      <namespace key="0" />\n      <namespace key="1">Talk</namespace>\n      <namespace key="2">User</namespace>\n      <namespace key="3">User talk</namespace>\n      <namespace key="4">Wikipedia</namespace>\n      <namespace key="5">Wikipedia talk</namespace>\n      <namespace key="6">Image</namespace>\n      <namespace key="7">Image talk</namespace>\n      <namespace key="8">MediaWiki</namespace>\n      <namespace key="9">MediaWiki talk</namesp

## Статический soft‑prompt

In [None]:
def create_soft_prompt(n_soft: int, model):
    """Create a trainable soft prompt of ``n_soft`` random vectors.

    Entries are drawn from a standard normal distribution; the width is
    ``model.config.hidden_size`` and the tensor lives on the module-level
    ``device``.
    """
    hidden_dim = model.config.hidden_size
    initial = torch.randn(n_soft, hidden_dim, device=device)
    return nn.Parameter(initial)

In [None]:
def encode_chunks_with_soft_prompt_and_embeddings(
    data: bytes,
    tokenizer,
    model,
    soft_prompt: torch.nn.Parameter,
    total: int = 1 << 30,
    chunk_size: int = 2000
):
    """Arithmetic-code ``data`` chunk by chunk with a soft-prompt prefix.

    For every token after the first of a chunk, the soft prompt is
    concatenated in embedding space with the embeddings of the preceding
    tokens, and the model's next-token distribution drives the coder.

    Args:
        data: Raw bytes to compress.
        tokenizer: Tokenizer exposing ``encode``.
        model: Causal LM accepting ``inputs_embeds``.
        soft_prompt: Prompt embeddings; the first dimension is the prompt
            length.
        total: Scale of the integer CDF passed to the coder.
        chunk_size: Number of bytes per independently coded chunk.

    Returns:
        One dict per chunk with keys ``'ids'`` and ``'bits'``.
    """
    t0 = time.perf_counter()
    # History positions left once the soft prompt is prepended.
    window = model.config.max_position_embeddings - soft_prompt.shape[0]
    prompt = soft_prompt.unsqueeze(0)
    encoded_chunks = []

    for offset in tqdm(range(0, len(data), chunk_size), desc='Encoding chunks'):
        piece = data[offset:offset + chunk_size]
        ids = tokenizer.encode(piece.decode('latin-1'), add_special_tokens=False)

        coder = ArithmeticEncoder()
        for pos in tqdm(range(1, len(ids)), desc='  Tokens in chunk', leave=False):
            past = ids[max(0, pos - window):pos]
            past_tensor = torch.tensor([past], device=device)
            past_embeds = model.get_input_embeddings()(past_tensor)
            full_embeds = torch.cat([prompt, past_embeds], dim=1)
            mask = torch.ones(full_embeds.shape[:2], device=device)

            with torch.no_grad():
                result = model(inputs_embeds=full_embeds, attention_mask=mask)
            dist = torch.softmax(result.logits[0, -1], dim=-1).cpu().numpy()

            cdf = probs_to_cdf_int(dist, total)
            target = ids[pos]
            coder.update(cdf[target], cdf[target + 1], total)

        coder.finish()
        encoded_chunks.append({'ids': ids, 'bits': coder.out})

    elapsed = time.perf_counter() - t0
    print(f"Total encoding time: {elapsed:.2f} seconds")
    return encoded_chunks


In [None]:
def decode_with_soft_prompt(
    chunk_data,
    model,
    soft_prompt: torch.nn.Parameter,
    total: int = 1 << 30
):
    """Decode chunks that were encoded with a soft-prompt prefix.

    The context for each step is built from the tokens decoded so far.
    From ``entry['ids']`` only the first token and the token count are
    used.

    Args:
        chunk_data: List of dicts with keys ``'ids'`` and ``'bits'``.
        model: Causal LM accepting ``inputs_embeds``.
        soft_prompt: The same prompt embeddings used for encoding.
        total: CDF scale; must match the encoder's.

    Returns:
        Flat list of recovered token ids over all chunks.
    """
    t0 = time.perf_counter()
    # History positions left once the soft prompt is prepended.
    window = model.config.max_position_embeddings - soft_prompt.shape[0]
    prompt = soft_prompt.unsqueeze(0)
    recovered = []

    for entry in tqdm(chunk_data, desc="Decoding chunks"):
        ids = entry['ids']
        decoder = ArithmeticDecoder(entry['bits'])

        # ids may hold tensor elements; unwrap the seed token to a scalar.
        seed = ids[0]
        if hasattr(seed, 'item'):
            seed = seed.item()
        decoded = [seed]

        for pos in tqdm(range(1, len(ids)), desc='  Tokens in chunk', leave=False):
            past = decoded[max(0, pos - window):pos]
            past_tensor = torch.tensor(past, dtype=torch.long, device=device).unsqueeze(0)
            past_embeds = model.get_input_embeddings()(past_tensor)
            full_embeds = torch.cat([prompt, past_embeds], dim=1)
            mask = torch.ones(full_embeds.shape[:2], device=device)

            with torch.no_grad():
                result = model(inputs_embeds=full_embeds, attention_mask=mask)
            dist = torch.softmax(result.logits[0, -1], dim=-1).cpu().numpy()

            cdf = probs_to_cdf_int(dist, total)
            decoded.append(decoder.decode(cdf, total))

        recovered.extend(decoded)

    elapsed = time.perf_counter() - t0
    print(f"Decoding done in {elapsed:.2f}s")
    return recovered

In [None]:
# Untrained, randomly initialised soft prompt of 200 vectors.
soft_prompt = create_soft_prompt(200, model)

In [None]:
# Compress the same data with the random soft prompt as the prefix.
chunk_data = encode_chunks_with_soft_prompt_and_embeddings(
    data, tokenizer, model, soft_prompt)

Encoding chunks:   0%|          | 0/25 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/668 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/732 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/572 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/589 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/577 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/546 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/505 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/568 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/500 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/511 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/487 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/456 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/465 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/514 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/426 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/537 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/600 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/473 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/495 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/442 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/452 [00:00<?, ?it/s]

Total encoding time: 197.45 seconds


In [None]:
# Decode the soft-prompt bitstreams; note the result is bound to `all_recovered_ids`.
all_recovered_ids = decode_with_soft_prompt(
    chunk_data, model, soft_prompt)

Decoding chunks:   0%|          | 0/25 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/668 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/732 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/572 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/589 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/577 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/546 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/505 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/568 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/500 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/511 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/487 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/456 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/465 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/514 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/426 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/537 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/600 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/473 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/495 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/442 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/452 [00:00<?, ?it/s]

Decoding done in 195.32s


In [None]:
# Flatten the per-chunk token ids and bits into single lists.
# A comprehension is linear; sum(lists, []) re-copies the accumulator at
# every step and is quadratic in the total length.
orig_ids = [tok for entry in chunk_data for tok in entry['ids']]
encoded_output = [bit for entry in chunk_data for bit in entry['bits']]

In [None]:
# Verify the soft-prompt decoder's own output (`all_recovered_ids`).
# `recovered_ids` is left over from the fixed-string experiment and would
# make this check pass without testing the soft-prompt round trip.
compare_sequences(orig_ids, all_recovered_ids)

✅  Совпадают полностью!


In [None]:
# Report sizes in bits for the soft-prompt run.
compression_stats(data, encoded_output)

Исходный размер:   400000 бит
Размер после сжатия: 68138 бит
Коэффициент сжатия: 0.1703


## Обучаемый soft‑prompt

In [None]:
class TextChunkDataset(Dataset):
    """Dataset of token-id tensors, one per fixed-size byte chunk of a file.

    Args:
        path: File to read, opened in binary mode.
        chunk_size: Number of bytes per chunk.
        tokenizer: Tokenizer exposing ``encode``.
        max_bytes: If truthy, only the first ``max_bytes`` bytes are used.
    """

    def __init__(self, path, chunk_size, tokenizer, max_bytes=None):
        # The context manager closes the file; for a positive limit only
        # the needed prefix is read rather than the whole file.
        with open(path, 'rb') as f:
            if max_bytes is not None and max_bytes > 0:
                raw = f.read(max_bytes)
            else:
                raw = f.read()
                if max_bytes:
                    # Negative limits keep their original slicing meaning.
                    raw = raw[:max_bytes]

        self.chunks = []
        for i in range(0, len(raw), chunk_size):
            # latin-1 maps every byte to a character, so decoding cannot fail.
            text = raw[i:i + chunk_size].decode('latin-1')
            ids = tokenizer.encode(text, add_special_tokens=False)
            # Single-token chunks are dropped: the callers in this file
            # need at least two tokens per chunk.
            if len(ids) > 1:
                self.chunks.append(ids)

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        return torch.tensor(self.chunks[idx], dtype=torch.long)

In [None]:
def create_data_loader(
    tokenizer,
    chunk_size: int,
    max_bytes: int,
    path: str = 'enwik8'
):
    """Build the chunk dataset and a shuffling, batch-size-1 loader over it.

    Args:
        tokenizer: Tokenizer passed through to ``TextChunkDataset``.
        chunk_size: Number of bytes per chunk.
        max_bytes: Limit on the number of bytes read from the file.
        path: Corpus file; defaults to ``'enwik8'``, the path that was
            previously hard-coded here.

    Returns:
        ``(dataset, loader)``.
    """
    dataset = TextChunkDataset(path, chunk_size, tokenizer, max_bytes=max_bytes)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    return dataset, loader

In [None]:
def train_soft_prompt(
    loader,
    model,
    soft_prompt: torch.nn.Parameter,
    optimizer,
    epochs: int
):
    """Train ``soft_prompt`` to minimise the LM loss on the loader's chunks.

    Each batch is one chunk of token ids. The soft prompt is prepended in
    embedding space, and its positions are labelled -100 so that only the
    real tokens have targets in the loss.

    Args:
        loader: Iterable of batches; ``batch[0]`` is a 1-D tensor of ids.
        model: Causal LM accepting ``inputs_embeds`` and ``labels``.
        soft_prompt: Prompt embeddings to optimise.
        optimizer: Optimizer stepped once per batch.
        epochs: Number of passes over ``loader``.
    """
    # NOTE(review): sequences are not truncated, so the prompt length plus
    # the chunk length is assumed to fit the model's context window.
    n_soft = soft_prompt.size(0)

    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        steps = 0
        for batch in tqdm(loader, desc=f"Epoch {epoch}"):
            ids = batch[0].to(device)
            L = ids.size(0)
            if L < 2:
                continue

            optimizer.zero_grad()

            embeds = model.get_input_embeddings()(ids.unsqueeze(0))
            sp = soft_prompt.unsqueeze(0)
            inp_e = torch.cat([sp, embeds], dim=1)
            attn = torch.ones(inp_e.shape[:2], device=device)

            # -100 marks the soft-prompt positions as having no target.
            labels = torch.full((1, n_soft + L), -100, dtype=torch.long, device=device)
            labels[0, n_soft:] = ids

            out = model(inputs_embeds=inp_e, attention_mask=attn, labels=labels)
            loss = out.loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            steps += 1

        # Average over the batches actually trained on; skipped batches no
        # longer dilute it and an empty loader no longer divides by zero.
        avg = total_loss / steps if steps else float('nan')
        print(f"Epoch {epoch} done — avg loss: {avg:.6f}")

In [None]:
def encode_with_soft_prompt(
    dataset,
    model,
    soft_prompt: torch.nn.Parameter,
    total: int = 1 << 30
):
    """Arithmetic-code every chunk of ``dataset`` with a soft-prompt prefix.

    Args:
        dataset: Iterable yielding 1-D tensors of token ids.
        model: Causal LM accepting ``inputs_embeds``.
        soft_prompt: Prompt embeddings prepended to every context.
        total: Scale of the integer CDF passed to the coder.

    Returns:
        One dict per chunk with keys ``'ids'`` (a CPU tensor) and ``'bits'``.
    """
    t0 = time.perf_counter()
    # History positions left once the soft prompt is prepended.
    window = model.config.max_position_embeddings - soft_prompt.shape[0]
    prompt = soft_prompt.unsqueeze(0)
    encoded_chunks = []

    for chunk_ids in tqdm(dataset, desc="Encoding chunks"):
        chunk_ids = chunk_ids.to(device)
        coder = ArithmeticEncoder()

        for pos in tqdm(range(1, chunk_ids.size(0)), desc='  Tokens in chunk', leave=False):
            past = chunk_ids[max(0, pos - window):pos].unsqueeze(0)
            past_embeds = model.get_input_embeddings()(past)
            full_embeds = torch.cat([prompt, past_embeds], dim=1)
            mask = torch.ones(full_embeds.shape[:2], device=device)

            with torch.no_grad():
                result = model(inputs_embeds=full_embeds, attention_mask=mask)
            dist = torch.softmax(result.logits[0, -1], dim=-1).cpu().numpy()

            cdf = probs_to_cdf_int(dist, total)
            target = chunk_ids[pos].item()
            coder.update(cdf[target], cdf[target + 1], total)

        coder.finish()
        encoded_chunks.append({'ids': chunk_ids.cpu(), 'bits': coder.out})

    elapsed = time.perf_counter() - t0
    print(f"Encoding done in {elapsed:.2f}s")
    return encoded_chunks

### 500 эпох обучения

In [None]:
# Fresh randomly initialised soft prompt of 200 vectors, to be trained below.
soft_prompt = create_soft_prompt(200, model)

In [None]:
# 2000-byte chunks over the first 50,000 bytes of the corpus.
dataset, loader = create_data_loader(tokenizer, chunk_size=2000, max_bytes=50000)

In [None]:
# Only the soft prompt is optimised, so the language model stays frozen
# and in eval mode. Leaving it in train mode would also affect the later
# encode/decode cells, which never switch it back; any active dropout
# would then make encoder and decoder probabilities disagree.
model.eval()
# Skip computing and storing gradients for the base weights.
model.requires_grad_(False)
optimizer = torch.optim.Adam([soft_prompt], lr=5e-4)

In [None]:
# Train the soft prompt for 500 epochs over the loader.
train_soft_prompt(loader, model, soft_prompt, optimizer, 500)

Epoch 1:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1 done — avg loss: 3.711635


Epoch 2:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 2 done — avg loss: 3.689035


Epoch 3:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 3 done — avg loss: 3.669818


Epoch 4:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 4 done — avg loss: 3.656071


Epoch 5:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 5 done — avg loss: 3.644554


Epoch 6:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 6 done — avg loss: 3.634199


Epoch 7:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 7 done — avg loss: 3.624754


Epoch 8:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 8 done — avg loss: 3.616023


Epoch 9:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 9 done — avg loss: 3.607551


Epoch 10:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 10 done — avg loss: 3.599690


Epoch 11:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 11 done — avg loss: 3.592850


Epoch 12:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 12 done — avg loss: 3.585516


Epoch 13:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 13 done — avg loss: 3.578043


Epoch 14:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 14 done — avg loss: 3.571863


Epoch 15:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 15 done — avg loss: 3.566342


Epoch 16:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 16 done — avg loss: 3.561180


Epoch 17:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 17 done — avg loss: 3.556479


Epoch 18:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 18 done — avg loss: 3.551828


Epoch 19:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 19 done — avg loss: 3.547446


Epoch 20:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 20 done — avg loss: 3.543259


Epoch 21:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 21 done — avg loss: 3.539324


Epoch 22:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 22 done — avg loss: 3.535346


Epoch 23:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 23 done — avg loss: 3.531846


Epoch 24:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 24 done — avg loss: 3.528566


Epoch 25:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 25 done — avg loss: 3.524896


Epoch 26:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 26 done — avg loss: 3.521603


Epoch 27:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 27 done — avg loss: 3.517853


Epoch 28:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 28 done — avg loss: 3.514314


Epoch 29:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 29 done — avg loss: 3.511749


Epoch 30:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 30 done — avg loss: 3.508548


Epoch 31:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 31 done — avg loss: 3.505361


Epoch 32:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 32 done — avg loss: 3.502728


Epoch 33:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 33 done — avg loss: 3.500535


Epoch 34:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 34 done — avg loss: 3.497109


Epoch 35:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 35 done — avg loss: 3.494395


Epoch 36:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 36 done — avg loss: 3.491515


Epoch 37:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 37 done — avg loss: 3.488898


Epoch 38:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 38 done — avg loss: 3.486666


Epoch 39:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 39 done — avg loss: 3.483963


Epoch 40:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 40 done — avg loss: 3.481093


Epoch 41:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 41 done — avg loss: 3.478844


Epoch 42:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 42 done — avg loss: 3.476613


Epoch 43:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 43 done — avg loss: 3.474585


Epoch 44:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 44 done — avg loss: 3.471833


Epoch 45:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 45 done — avg loss: 3.469508


Epoch 46:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 46 done — avg loss: 3.467662


Epoch 47:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 47 done — avg loss: 3.465858


Epoch 48:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 48 done — avg loss: 3.463106


Epoch 49:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 49 done — avg loss: 3.460688


Epoch 50:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 50 done — avg loss: 3.459662


Epoch 51:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 51 done — avg loss: 3.457518


Epoch 52:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 52 done — avg loss: 3.455393


Epoch 53:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 53 done — avg loss: 3.453393


Epoch 54:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 54 done — avg loss: 3.451346


Epoch 55:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 55 done — avg loss: 3.449045


Epoch 56:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 56 done — avg loss: 3.447389


Epoch 57:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 57 done — avg loss: 3.446023


Epoch 58:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 58 done — avg loss: 3.443752


Epoch 59:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 59 done — avg loss: 3.442305


Epoch 60:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 60 done — avg loss: 3.440682


Epoch 61:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 61 done — avg loss: 3.437802


Epoch 62:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 62 done — avg loss: 3.436414


Epoch 63:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 63 done — avg loss: 3.434274


Epoch 64:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 64 done — avg loss: 3.432554


Epoch 65:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 65 done — avg loss: 3.431459


Epoch 66:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 66 done — avg loss: 3.429536


Epoch 67:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 67 done — avg loss: 3.428312


Epoch 68:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 68 done — avg loss: 3.426693


Epoch 69:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 69 done — avg loss: 3.424888


Epoch 70:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 70 done — avg loss: 3.423539


Epoch 71:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 71 done — avg loss: 3.421572


Epoch 72:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 72 done — avg loss: 3.419482


Epoch 73:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 73 done — avg loss: 3.417476


Epoch 74:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 74 done — avg loss: 3.415951


Epoch 75:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 75 done — avg loss: 3.414505


Epoch 76:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 76 done — avg loss: 3.413228


Epoch 77:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 77 done — avg loss: 3.411133


Epoch 78:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 78 done — avg loss: 3.409313


Epoch 79:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 79 done — avg loss: 3.407503


Epoch 80:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 80 done — avg loss: 3.405828


Epoch 81:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 81 done — avg loss: 3.404162


Epoch 82:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 82 done — avg loss: 3.402521


Epoch 83:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 83 done — avg loss: 3.401173


Epoch 84:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 84 done — avg loss: 3.399206


Epoch 85:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 85 done — avg loss: 3.397949


Epoch 86:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 86 done — avg loss: 3.396091


Epoch 87:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 87 done — avg loss: 3.394416


Epoch 88:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 88 done — avg loss: 3.393734


Epoch 89:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 89 done — avg loss: 3.391732


Epoch 90:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 90 done — avg loss: 3.390379


Epoch 91:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 91 done — avg loss: 3.388674


Epoch 92:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 92 done — avg loss: 3.387346


Epoch 93:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 93 done — avg loss: 3.387124


Epoch 94:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 94 done — avg loss: 3.384577


Epoch 95:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 95 done — avg loss: 3.382421


Epoch 96:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 96 done — avg loss: 3.380274


Epoch 97:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 97 done — avg loss: 3.379288


Epoch 98:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 98 done — avg loss: 3.377446


Epoch 99:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 99 done — avg loss: 3.375849


Epoch 100:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 100 done — avg loss: 3.374603


Epoch 101:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 101 done — avg loss: 3.373971


Epoch 102:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 102 done — avg loss: 3.372269


Epoch 103:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 103 done — avg loss: 3.369988


Epoch 104:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 104 done — avg loss: 3.369295


Epoch 105:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 105 done — avg loss: 3.369060


Epoch 106:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 106 done — avg loss: 3.367711


Epoch 107:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 107 done — avg loss: 3.364806


Epoch 108:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 108 done — avg loss: 3.363179


Epoch 109:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 109 done — avg loss: 3.362264


Epoch 110:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 110 done — avg loss: 3.360340


Epoch 111:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 111 done — avg loss: 3.358798


Epoch 112:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 112 done — avg loss: 3.356299


Epoch 113:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 113 done — avg loss: 3.355226


Epoch 114:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 114 done — avg loss: 3.353353


Epoch 115:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 115 done — avg loss: 3.352357


Epoch 116:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 116 done — avg loss: 3.350439


Epoch 117:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 117 done — avg loss: 3.348806


Epoch 118:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 118 done — avg loss: 3.347271


Epoch 119:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 119 done — avg loss: 3.346091


Epoch 120:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 120 done — avg loss: 3.344650


Epoch 121:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 121 done — avg loss: 3.342853


Epoch 122:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 122 done — avg loss: 3.341354


Epoch 123:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 123 done — avg loss: 3.340370


Epoch 124:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 124 done — avg loss: 3.338834


Epoch 125:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 125 done — avg loss: 3.336942


Epoch 126:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 126 done — avg loss: 3.335399


Epoch 127:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 127 done — avg loss: 3.333555


Epoch 128:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 128 done — avg loss: 3.332201


Epoch 129:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 129 done — avg loss: 3.331300


Epoch 130:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 130 done — avg loss: 3.330185


Epoch 131:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 131 done — avg loss: 3.328718


Epoch 132:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 132 done — avg loss: 3.327183


Epoch 133:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 133 done — avg loss: 3.325777


Epoch 134:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 134 done — avg loss: 3.324565


Epoch 135:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 135 done — avg loss: 3.323273


Epoch 136:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 136 done — avg loss: 3.320720


Epoch 137:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 137 done — avg loss: 3.319421


Epoch 138:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 138 done — avg loss: 3.318984


Epoch 139:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 139 done — avg loss: 3.316703


Epoch 140:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 140 done — avg loss: 3.314581


Epoch 141:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 141 done — avg loss: 3.313116


Epoch 142:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 142 done — avg loss: 3.312120


Epoch 143:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 143 done — avg loss: 3.310475


Epoch 144:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 144 done — avg loss: 3.309490


Epoch 145:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 145 done — avg loss: 3.307217


Epoch 146:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 146 done — avg loss: 3.305346


Epoch 147:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 147 done — avg loss: 3.304367


Epoch 148:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 148 done — avg loss: 3.302676


Epoch 149:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 149 done — avg loss: 3.301615


Epoch 150:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 150 done — avg loss: 3.300296


Epoch 151:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 151 done — avg loss: 3.299426


Epoch 152:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 152 done — avg loss: 3.297304


Epoch 153:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 153 done — avg loss: 3.296053


Epoch 154:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 154 done — avg loss: 3.295010


Epoch 155:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 155 done — avg loss: 3.293892


Epoch 156:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 156 done — avg loss: 3.292993


Epoch 157:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 157 done — avg loss: 3.291356


Epoch 158:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 158 done — avg loss: 3.289569


Epoch 159:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 159 done — avg loss: 3.288471


Epoch 160:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 160 done — avg loss: 3.287599


Epoch 161:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 161 done — avg loss: 3.284871


Epoch 162:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 162 done — avg loss: 3.282588


Epoch 163:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 163 done — avg loss: 3.281489


Epoch 164:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 164 done — avg loss: 3.279516


Epoch 165:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 165 done — avg loss: 3.278209


Epoch 166:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 166 done — avg loss: 3.276704


Epoch 167:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 167 done — avg loss: 3.275121


Epoch 168:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 168 done — avg loss: 3.274157


Epoch 169:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 169 done — avg loss: 3.272698


Epoch 170:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 170 done — avg loss: 3.271021


Epoch 171:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 171 done — avg loss: 3.269436


Epoch 172:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 172 done — avg loss: 3.268119


Epoch 173:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 173 done — avg loss: 3.266211


Epoch 174:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 174 done — avg loss: 3.264868


Epoch 175:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 175 done — avg loss: 3.263956


Epoch 176:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 176 done — avg loss: 3.262314


Epoch 177:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 177 done — avg loss: 3.260363


Epoch 178:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 178 done — avg loss: 3.259719


Epoch 179:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 179 done — avg loss: 3.258128


Epoch 180:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 180 done — avg loss: 3.256347


Epoch 181:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 181 done — avg loss: 3.254897


Epoch 182:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 182 done — avg loss: 3.253771


Epoch 183:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 183 done — avg loss: 3.252405


Epoch 184:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 184 done — avg loss: 3.250883


Epoch 185:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 185 done — avg loss: 3.248790


Epoch 186:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 186 done — avg loss: 3.247814


Epoch 187:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 187 done — avg loss: 3.246651


Epoch 188:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 188 done — avg loss: 3.244791


Epoch 189:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 189 done — avg loss: 3.243567


Epoch 190:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 190 done — avg loss: 3.242522


Epoch 191:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 191 done — avg loss: 3.241260


Epoch 192:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 192 done — avg loss: 3.240415


Epoch 193:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 193 done — avg loss: 3.238897


Epoch 194:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 194 done — avg loss: 3.238559


Epoch 195:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 195 done — avg loss: 3.236165


Epoch 196:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 196 done — avg loss: 3.235828


Epoch 197:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 197 done — avg loss: 3.233862


Epoch 198:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 198 done — avg loss: 3.232346


Epoch 199:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 199 done — avg loss: 3.231184


Epoch 200:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 200 done — avg loss: 3.229691


Epoch 201:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 201 done — avg loss: 3.228660


Epoch 202:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 202 done — avg loss: 3.227383


Epoch 203:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 203 done — avg loss: 3.225668


Epoch 204:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 204 done — avg loss: 3.223757


Epoch 205:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 205 done — avg loss: 3.222377


Epoch 206:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 206 done — avg loss: 3.222028


Epoch 207:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 207 done — avg loss: 3.220255


Epoch 208:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 208 done — avg loss: 3.218976


Epoch 209:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 209 done — avg loss: 3.218843


Epoch 210:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 210 done — avg loss: 3.216802


Epoch 211:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 211 done — avg loss: 3.215172


Epoch 212:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 212 done — avg loss: 3.213577


Epoch 213:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 213 done — avg loss: 3.213063


Epoch 214:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 214 done — avg loss: 3.211557


Epoch 215:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 215 done — avg loss: 3.210277


Epoch 216:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 216 done — avg loss: 3.209689


Epoch 217:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 217 done — avg loss: 3.208201


Epoch 218:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 218 done — avg loss: 3.207064


Epoch 219:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 219 done — avg loss: 3.205685


Epoch 220:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 220 done — avg loss: 3.203387


Epoch 221:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 221 done — avg loss: 3.202319


Epoch 222:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 222 done — avg loss: 3.201304


Epoch 223:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 223 done — avg loss: 3.200897


Epoch 224:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 224 done — avg loss: 3.198764


Epoch 225:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 225 done — avg loss: 3.197240


Epoch 226:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 226 done — avg loss: 3.196153


Epoch 227:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 227 done — avg loss: 3.194705


Epoch 228:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 228 done — avg loss: 3.193952


Epoch 229:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 229 done — avg loss: 3.194091


Epoch 230:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 230 done — avg loss: 3.191394


Epoch 231:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 231 done — avg loss: 3.189617


Epoch 232:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 232 done — avg loss: 3.189242


Epoch 233:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 233 done — avg loss: 3.187072


Epoch 234:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 234 done — avg loss: 3.186075


Epoch 235:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 235 done — avg loss: 3.185137


Epoch 236:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 236 done — avg loss: 3.183898


Epoch 237:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 237 done — avg loss: 3.182651


Epoch 238:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 238 done — avg loss: 3.181045


Epoch 239:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 239 done — avg loss: 3.180558


Epoch 240:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 240 done — avg loss: 3.178756


Epoch 241:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 241 done — avg loss: 3.177537


Epoch 242:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 242 done — avg loss: 3.176415


Epoch 243:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 243 done — avg loss: 3.175791


Epoch 244:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 244 done — avg loss: 3.173774


Epoch 245:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 245 done — avg loss: 3.173347


Epoch 246:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 246 done — avg loss: 3.172510


Epoch 247:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 247 done — avg loss: 3.170608


Epoch 248:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 248 done — avg loss: 3.169100


Epoch 249:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 249 done — avg loss: 3.167787


Epoch 250:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 250 done — avg loss: 3.165920


Epoch 251:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 251 done — avg loss: 3.165231


Epoch 252:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 252 done — avg loss: 3.163949


Epoch 253:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 253 done — avg loss: 3.163002


Epoch 254:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 254 done — avg loss: 3.162606


Epoch 255:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 255 done — avg loss: 3.161143


Epoch 256:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 256 done — avg loss: 3.159012


Epoch 257:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 257 done — avg loss: 3.157597


Epoch 258:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 258 done — avg loss: 3.157469


Epoch 259:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 259 done — avg loss: 3.156246


Epoch 260:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 260 done — avg loss: 3.155362


Epoch 261:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 261 done — avg loss: 3.154689


Epoch 262:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 262 done — avg loss: 3.152189


Epoch 263:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 263 done — avg loss: 3.151712


Epoch 264:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 264 done — avg loss: 3.149659


Epoch 265:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 265 done — avg loss: 3.149054


Epoch 266:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 266 done — avg loss: 3.148749


Epoch 267:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 267 done — avg loss: 3.148539


Epoch 268:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 268 done — avg loss: 3.146359


Epoch 269:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 269 done — avg loss: 3.145393


Epoch 270:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 270 done — avg loss: 3.143277


Epoch 271:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 271 done — avg loss: 3.141403


Epoch 272:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 272 done — avg loss: 3.141143


Epoch 273:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 273 done — avg loss: 3.140070


Epoch 274:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 274 done — avg loss: 3.138491


Epoch 275:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 275 done — avg loss: 3.138437


Epoch 276:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 276 done — avg loss: 3.137374


Epoch 277:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 277 done — avg loss: 3.135266


Epoch 278:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 278 done — avg loss: 3.135942


Epoch 279:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 279 done — avg loss: 3.133139


Epoch 280:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 280 done — avg loss: 3.132055


Epoch 281:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 281 done — avg loss: 3.131417


Epoch 282:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 282 done — avg loss: 3.129847


Epoch 283:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 283 done — avg loss: 3.127915


Epoch 284:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 284 done — avg loss: 3.127322


Epoch 285:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 285 done — avg loss: 3.125762


Epoch 286:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 286 done — avg loss: 3.125052


Epoch 287:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 287 done — avg loss: 3.123180


Epoch 288:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 288 done — avg loss: 3.122818


Epoch 289:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 289 done — avg loss: 3.121871


Epoch 290:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 290 done — avg loss: 3.120476


Epoch 291:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 291 done — avg loss: 3.120170


Epoch 292:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 292 done — avg loss: 3.118889


Epoch 293:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 293 done — avg loss: 3.117864


Epoch 294:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 294 done — avg loss: 3.116067


Epoch 295:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 295 done — avg loss: 3.114903


Epoch 296:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 296 done — avg loss: 3.114414


Epoch 297:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 297 done — avg loss: 3.112078


Epoch 298:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 298 done — avg loss: 3.110643


Epoch 299:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 299 done — avg loss: 3.110008


Epoch 300:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 300 done — avg loss: 3.108469


Epoch 301:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 301 done — avg loss: 3.108195


Epoch 302:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 302 done — avg loss: 3.107998


Epoch 303:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 303 done — avg loss: 3.105351


Epoch 304:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 304 done — avg loss: 3.104635


Epoch 305:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 305 done — avg loss: 3.103151


Epoch 306:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 306 done — avg loss: 3.102011


Epoch 307:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 307 done — avg loss: 3.101728


Epoch 308:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 308 done — avg loss: 3.100751


Epoch 309:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 309 done — avg loss: 3.099647


Epoch 310:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 310 done — avg loss: 3.098393


Epoch 311:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 311 done — avg loss: 3.096580


Epoch 312:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 312 done — avg loss: 3.096940


Epoch 313:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 313 done — avg loss: 3.094855


Epoch 314:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 314 done — avg loss: 3.093839


Epoch 315:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 315 done — avg loss: 3.092260


Epoch 316:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 316 done — avg loss: 3.091560


Epoch 317:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 317 done — avg loss: 3.090356


Epoch 318:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 318 done — avg loss: 3.090605


Epoch 319:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 319 done — avg loss: 3.088177


Epoch 320:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 320 done — avg loss: 3.087912


Epoch 321:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 321 done — avg loss: 3.086094


Epoch 322:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 322 done — avg loss: 3.085033


Epoch 323:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 323 done — avg loss: 3.084272


Epoch 324:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 324 done — avg loss: 3.083304


Epoch 325:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 325 done — avg loss: 3.083291


Epoch 326:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 326 done — avg loss: 3.081849


Epoch 327:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 327 done — avg loss: 3.080368


Epoch 328:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 328 done — avg loss: 3.079067


Epoch 329:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 329 done — avg loss: 3.077543


Epoch 330:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 330 done — avg loss: 3.076868


Epoch 331:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 331 done — avg loss: 3.076002


Epoch 332:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 332 done — avg loss: 3.074175


Epoch 333:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 333 done — avg loss: 3.073041


Epoch 334:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 334 done — avg loss: 3.073240


Epoch 335:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 335 done — avg loss: 3.071042


Epoch 336:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 336 done — avg loss: 3.070342


Epoch 337:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 337 done — avg loss: 3.068340


Epoch 338:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 338 done — avg loss: 3.068562


Epoch 339:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 339 done — avg loss: 3.067042


Epoch 340:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 340 done — avg loss: 3.065935


Epoch 341:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 341 done — avg loss: 3.066555


Epoch 342:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 342 done — avg loss: 3.064944


Epoch 343:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 343 done — avg loss: 3.063414


Epoch 344:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 344 done — avg loss: 3.061913


Epoch 345:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 345 done — avg loss: 3.060290


Epoch 346:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 346 done — avg loss: 3.059946


Epoch 347:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 347 done — avg loss: 3.058956


Epoch 348:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 348 done — avg loss: 3.058074


Epoch 349:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 349 done — avg loss: 3.056681


Epoch 350:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 350 done — avg loss: 3.055030


Epoch 351:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 351 done — avg loss: 3.054712


Epoch 352:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 352 done — avg loss: 3.053022


Epoch 353:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 353 done — avg loss: 3.053579


Epoch 354:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 354 done — avg loss: 3.052693


Epoch 355:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 355 done — avg loss: 3.050701


Epoch 356:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 356 done — avg loss: 3.048823


Epoch 357:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 357 done — avg loss: 3.048566


Epoch 358:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 358 done — avg loss: 3.047347


Epoch 359:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 359 done — avg loss: 3.045789


Epoch 360:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 360 done — avg loss: 3.046084


Epoch 361:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 361 done — avg loss: 3.044635


Epoch 362:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 362 done — avg loss: 3.044850


Epoch 363:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 363 done — avg loss: 3.044724


Epoch 364:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 364 done — avg loss: 3.042828


Epoch 365:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 365 done — avg loss: 3.041286


Epoch 366:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 366 done — avg loss: 3.040064


Epoch 367:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 367 done — avg loss: 3.037515


Epoch 368:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 368 done — avg loss: 3.036250


Epoch 369:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 369 done — avg loss: 3.035875


Epoch 370:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 370 done — avg loss: 3.034681


Epoch 371:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 371 done — avg loss: 3.034660


Epoch 372:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 372 done — avg loss: 3.033091


Epoch 373:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 373 done — avg loss: 3.030871


Epoch 374:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 374 done — avg loss: 3.030520


Epoch 375:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 375 done — avg loss: 3.028877


Epoch 376:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 376 done — avg loss: 3.028219


Epoch 377:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 377 done — avg loss: 3.026867


Epoch 378:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 378 done — avg loss: 3.026477


Epoch 379:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 379 done — avg loss: 3.025259


Epoch 380:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 380 done — avg loss: 3.024719


Epoch 381:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 381 done — avg loss: 3.024208


Epoch 382:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 382 done — avg loss: 3.022426


Epoch 383:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 383 done — avg loss: 3.021179


Epoch 384:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 384 done — avg loss: 3.019611


Epoch 385:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 385 done — avg loss: 3.018150


Epoch 386:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 386 done — avg loss: 3.018166


Epoch 387:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 387 done — avg loss: 3.017401


Epoch 388:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 388 done — avg loss: 3.015292


Epoch 389:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 389 done — avg loss: 3.015462


Epoch 390:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 390 done — avg loss: 3.015189


Epoch 391:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 391 done — avg loss: 3.013525


Epoch 392:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 392 done — avg loss: 3.012099


Epoch 393:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 393 done — avg loss: 3.011201


Epoch 394:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 394 done — avg loss: 3.010381


Epoch 395:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 395 done — avg loss: 3.010268


Epoch 396:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 396 done — avg loss: 3.007785


Epoch 397:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 397 done — avg loss: 3.008175


Epoch 398:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 398 done — avg loss: 3.006482


Epoch 399:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 399 done — avg loss: 3.005304


Epoch 400:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 400 done — avg loss: 3.004260


Epoch 401:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 401 done — avg loss: 3.003347


Epoch 402:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 402 done — avg loss: 3.002289


Epoch 403:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 403 done — avg loss: 3.002950


Epoch 404:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 404 done — avg loss: 3.002168


Epoch 405:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 405 done — avg loss: 3.000070


Epoch 406:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 406 done — avg loss: 2.999519


Epoch 407:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 407 done — avg loss: 2.998518


Epoch 408:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 408 done — avg loss: 2.998282


Epoch 409:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 409 done — avg loss: 2.997416


Epoch 410:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 410 done — avg loss: 2.996216


Epoch 411:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 411 done — avg loss: 2.995234


Epoch 412:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 412 done — avg loss: 2.993697


Epoch 413:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 413 done — avg loss: 2.992283


Epoch 414:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 414 done — avg loss: 2.990817


Epoch 415:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 415 done — avg loss: 2.990707


Epoch 416:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 416 done — avg loss: 2.989237


Epoch 417:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 417 done — avg loss: 2.987589


Epoch 418:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 418 done — avg loss: 2.987403


Epoch 419:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 419 done — avg loss: 2.986354


Epoch 420:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 420 done — avg loss: 2.984959


Epoch 421:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 421 done — avg loss: 2.986196


Epoch 422:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 422 done — avg loss: 2.985226


Epoch 423:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 423 done — avg loss: 2.983755


Epoch 424:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 424 done — avg loss: 2.982939


Epoch 425:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 425 done — avg loss: 2.981385


Epoch 426:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 426 done — avg loss: 2.980004


Epoch 427:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 427 done — avg loss: 2.977922


Epoch 428:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 428 done — avg loss: 2.977324


Epoch 429:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 429 done — avg loss: 2.976598


Epoch 430:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 430 done — avg loss: 2.975720


Epoch 431:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 431 done — avg loss: 2.974483


Epoch 432:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 432 done — avg loss: 2.973538


Epoch 433:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 433 done — avg loss: 2.973712


Epoch 434:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 434 done — avg loss: 2.974606


Epoch 435:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 435 done — avg loss: 2.972614


Epoch 436:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 436 done — avg loss: 2.970096


Epoch 437:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 437 done — avg loss: 2.969031


Epoch 438:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 438 done — avg loss: 2.968378


Epoch 439:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 439 done — avg loss: 2.967564


Epoch 440:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 440 done — avg loss: 2.966463


Epoch 441:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 441 done — avg loss: 2.967072


Epoch 442:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 442 done — avg loss: 2.966604


Epoch 443:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 443 done — avg loss: 2.965075


Epoch 444:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 444 done — avg loss: 2.964273


Epoch 445:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 445 done — avg loss: 2.962218


Epoch 446:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 446 done — avg loss: 2.961466


Epoch 447:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 447 done — avg loss: 2.961590


Epoch 448:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 448 done — avg loss: 2.959467


Epoch 449:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 449 done — avg loss: 2.957792


Epoch 450:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 450 done — avg loss: 2.956870


Epoch 451:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 451 done — avg loss: 2.956835


Epoch 452:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 452 done — avg loss: 2.955552


Epoch 453:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 453 done — avg loss: 2.955098


Epoch 454:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 454 done — avg loss: 2.954520


Epoch 455:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 455 done — avg loss: 2.953125


Epoch 456:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 456 done — avg loss: 2.952632


Epoch 457:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 457 done — avg loss: 2.953328


Epoch 458:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 458 done — avg loss: 2.950524


Epoch 459:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 459 done — avg loss: 2.949933


Epoch 460:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 460 done — avg loss: 2.949744


Epoch 461:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 461 done — avg loss: 2.948663


Epoch 462:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 462 done — avg loss: 2.948362


Epoch 463:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 463 done — avg loss: 2.948243


Epoch 464:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 464 done — avg loss: 2.946515


Epoch 465:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 465 done — avg loss: 2.945023


Epoch 466:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 466 done — avg loss: 2.945486


Epoch 467:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 467 done — avg loss: 2.943032


Epoch 468:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 468 done — avg loss: 2.941715


Epoch 469:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 469 done — avg loss: 2.941524


Epoch 470:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 470 done — avg loss: 2.939568


Epoch 471:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 471 done — avg loss: 2.940370


Epoch 472:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 472 done — avg loss: 2.939514


Epoch 473:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 473 done — avg loss: 2.936993


Epoch 474:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 474 done — avg loss: 2.936579


Epoch 475:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 475 done — avg loss: 2.935234


Epoch 476:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 476 done — avg loss: 2.935046


Epoch 477:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 477 done — avg loss: 2.934475


Epoch 478:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 478 done — avg loss: 2.934572


Epoch 479:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 479 done — avg loss: 2.934329


Epoch 480:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 480 done — avg loss: 2.931834


Epoch 481:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 481 done — avg loss: 2.930770


Epoch 482:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 482 done — avg loss: 2.931122


Epoch 483:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 483 done — avg loss: 2.929389


Epoch 484:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 484 done — avg loss: 2.927483


Epoch 485:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 485 done — avg loss: 2.928388


Epoch 486:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 486 done — avg loss: 2.928127


Epoch 487:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 487 done — avg loss: 2.926511


Epoch 488:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 488 done — avg loss: 2.924089


Epoch 489:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 489 done — avg loss: 2.923913


Epoch 490:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 490 done — avg loss: 2.923037


Epoch 491:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 491 done — avg loss: 2.922260


Epoch 492:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 492 done — avg loss: 2.922034


Epoch 493:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 493 done — avg loss: 2.922146


Epoch 494:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 494 done — avg loss: 2.920044


Epoch 495:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 495 done — avg loss: 2.919173


Epoch 496:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 496 done — avg loss: 2.917941


Epoch 497:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 497 done — avg loss: 2.917628


Epoch 498:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 498 done — avg loss: 2.915981


Epoch 499:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 499 done — avg loss: 2.915011


Epoch 500:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 500 done — avg loss: 2.914692


In [None]:
# Switch the model to evaluation mode before the encode/decode pass
# (presumably so encoder and decoder see identical, deterministic
# probabilities — confirm). The trailing semicolon keeps the notebook
# from echoing the call's return value.
model.eval();

In [None]:
# Encode the chunks of `dataset` with the model and the trained soft prompt
# (helper defined elsewhere in the notebook). Each resulting entry is later
# read through its 'ids' and 'bits' keys.
chunk_data = encode_with_soft_prompt(
    dataset, model, soft_prompt)

Encoding chunks:   0%|          | 0/25 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/668 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/732 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/572 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/589 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/577 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/546 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/505 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/568 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/500 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/511 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/487 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/456 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/465 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/514 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/426 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/537 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/600 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/473 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/495 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/442 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/452 [00:00<?, ?it/s]

Encoding done in 217.47s


In [None]:
# Decode the per-chunk encodings, passing the same model and soft prompt
# that were given to the encoder. The result is checked against the
# original token ids further down.
all_rec = decode_with_soft_prompt(
    chunk_data, model, soft_prompt)

Decoding chunks:   0%|          | 0/25 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/668 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/732 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/572 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/589 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/577 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/546 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/505 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/568 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/500 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/511 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/487 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/456 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/465 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/514 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/426 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/537 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/600 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/473 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/495 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/442 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/452 [00:00<?, ?it/s]

Decoding done in 200.10s


In [None]:
# Flatten the per-chunk results into single flat lists. Nested comprehensions
# run in linear time, whereas sum(lists, []) re-copies the growing
# accumulator on every addition (quadratic in the total length).
orig_ids = [tok for entry in chunk_data for tok in entry['ids'].tolist()]
encoded_output = [bit for entry in chunk_data for bit in entry['bits']]

In [None]:
# Losslessness check: compare the original token ids with the decoded ones.
compare_sequences(orig_ids, all_rec)

✅  Совпадают полностью!


In [None]:
# Report original size, compressed size (in bits) and their ratio.
# NOTE(review): `encoded_output` holds only the concatenated per-chunk 'bits';
# the soft prompt that the decoder also takes as input is not part of this
# payload — confirm whether it should be counted in the ratio.
# NOTE(review): `data` is defined outside this section — presumably the raw
# input the dataset was built from; verify.
compression_stats(data, encoded_output)

Исходный размер:   400000 бит
Размер после сжатия: 54256 бит
Коэффициент сжатия: 0.1356


### 100 эпох обучения

In [None]:
# Build the dataset and its loader for the 100-epoch run.
# NOTE(review): chunk_size / max_bytes are presumably byte counts
# (50000 / 2000 = 25, matching the 25 chunks reported in the output) —
# confirm in create_data_loader.
dataset, loader = create_data_loader(tokenizer, chunk_size=2000, max_bytes=50000)

In [None]:
# Create a fresh soft prompt for this run, replacing the previously trained one.
# NOTE(review): 200 is presumably the number of prompt positions —
# confirm in create_soft_prompt.
soft_prompt = create_soft_prompt(200, model)

In [None]:
# Put the model in training mode and set up Adam over the soft prompt only:
# it is the sole tensor handed to the optimizer, so no model weights are
# updated by it.
# NOTE(review): train mode enables dropout if the model has any — confirm
# this is intended while only the prompt is being optimised.
model.train()
optimizer = torch.optim.Adam([soft_prompt], lr=5e-4)

In [None]:
# Train the soft prompt; the last argument (100) is the number of epochs,
# and the average loss is printed after each one.
train_soft_prompt(loader, model, soft_prompt, optimizer, 100)

Epoch 1:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1 done — avg loss: 3.684460


Epoch 2:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 2 done — avg loss: 3.669685


Epoch 3:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 3 done — avg loss: 3.657713


Epoch 4:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 4 done — avg loss: 3.646687


Epoch 5:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 5 done — avg loss: 3.636851


Epoch 6:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 6 done — avg loss: 3.627662


Epoch 7:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 7 done — avg loss: 3.618940


Epoch 8:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 8 done — avg loss: 3.611479


Epoch 9:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 9 done — avg loss: 3.604457


Epoch 10:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 10 done — avg loss: 3.598027


Epoch 11:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 11 done — avg loss: 3.591756


Epoch 12:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 12 done — avg loss: 3.585716


Epoch 13:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 13 done — avg loss: 3.579705


Epoch 14:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 14 done — avg loss: 3.574540


Epoch 15:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 15 done — avg loss: 3.569151


Epoch 16:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 16 done — avg loss: 3.563954


Epoch 17:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 17 done — avg loss: 3.559332


Epoch 18:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 18 done — avg loss: 3.554615


Epoch 19:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 19 done — avg loss: 3.550216


Epoch 20:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 20 done — avg loss: 3.545719


Epoch 21:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 21 done — avg loss: 3.541339


Epoch 22:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 22 done — avg loss: 3.536806


Epoch 23:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 23 done — avg loss: 3.533169


Epoch 24:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 24 done — avg loss: 3.529127


Epoch 25:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 25 done — avg loss: 3.525066


Epoch 26:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 26 done — avg loss: 3.521486


Epoch 27:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 27 done — avg loss: 3.517454


Epoch 28:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 28 done — avg loss: 3.513973


Epoch 29:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 29 done — avg loss: 3.510627


Epoch 30:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 30 done — avg loss: 3.507068


Epoch 31:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 31 done — avg loss: 3.503801


Epoch 32:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 32 done — avg loss: 3.500651


Epoch 33:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 33 done — avg loss: 3.497624


Epoch 34:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 34 done — avg loss: 3.494686


Epoch 35:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 35 done — avg loss: 3.491310


Epoch 36:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 36 done — avg loss: 3.488552


Epoch 37:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 37 done — avg loss: 3.485822


Epoch 38:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 38 done — avg loss: 3.483820


Epoch 39:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 39 done — avg loss: 3.481296


Epoch 40:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 40 done — avg loss: 3.478790


Epoch 41:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 41 done — avg loss: 3.475467


Epoch 42:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 42 done — avg loss: 3.472717


Epoch 43:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 43 done — avg loss: 3.470477


Epoch 44:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 44 done — avg loss: 3.468314


Epoch 45:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 45 done — avg loss: 3.466685


Epoch 46:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 46 done — avg loss: 3.464330


Epoch 47:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 47 done — avg loss: 3.461662


Epoch 48:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 48 done — avg loss: 3.459222


Epoch 49:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 49 done — avg loss: 3.456907


Epoch 50:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 50 done — avg loss: 3.454901


Epoch 51:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 51 done — avg loss: 3.452752


Epoch 52:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 52 done — avg loss: 3.450718


Epoch 53:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 53 done — avg loss: 3.449061


Epoch 54:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 54 done — avg loss: 3.447294


Epoch 55:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 55 done — avg loss: 3.445389


Epoch 56:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 56 done — avg loss: 3.443580


Epoch 57:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 57 done — avg loss: 3.441349


Epoch 58:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 58 done — avg loss: 3.439132


Epoch 59:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 59 done — avg loss: 3.437415


Epoch 60:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 60 done — avg loss: 3.435336


Epoch 61:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 61 done — avg loss: 3.433556


Epoch 62:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 62 done — avg loss: 3.431819


Epoch 63:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 63 done — avg loss: 3.429807


Epoch 64:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 64 done — avg loss: 3.428225


Epoch 65:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 65 done — avg loss: 3.426935


Epoch 66:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 66 done — avg loss: 3.425465


Epoch 67:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 67 done — avg loss: 3.423425


Epoch 68:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 68 done — avg loss: 3.421976


Epoch 69:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 69 done — avg loss: 3.419745


Epoch 70:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 70 done — avg loss: 3.418121


Epoch 71:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 71 done — avg loss: 3.416483


Epoch 72:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 72 done — avg loss: 3.415014


Epoch 73:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 73 done — avg loss: 3.412728


Epoch 74:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 74 done — avg loss: 3.410738


Epoch 75:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 75 done — avg loss: 3.409286


Epoch 76:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 76 done — avg loss: 3.407229


Epoch 77:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 77 done — avg loss: 3.406057


Epoch 78:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 78 done — avg loss: 3.404347


Epoch 79:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 79 done — avg loss: 3.402555


Epoch 80:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 80 done — avg loss: 3.401524


Epoch 81:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 81 done — avg loss: 3.399893


Epoch 82:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 82 done — avg loss: 3.398504


Epoch 83:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 83 done — avg loss: 3.396343


Epoch 84:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 84 done — avg loss: 3.394281


Epoch 85:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 85 done — avg loss: 3.392721


Epoch 86:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 86 done — avg loss: 3.391344


Epoch 87:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 87 done — avg loss: 3.389900


Epoch 88:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 88 done — avg loss: 3.388667


Epoch 89:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 89 done — avg loss: 3.387566


Epoch 90:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 90 done — avg loss: 3.386124


Epoch 91:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 91 done — avg loss: 3.384131


Epoch 92:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 92 done — avg loss: 3.382702


Epoch 93:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 93 done — avg loss: 3.380780


Epoch 94:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 94 done — avg loss: 3.378911


Epoch 95:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 95 done — avg loss: 3.377382


Epoch 96:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 96 done — avg loss: 3.375954


Epoch 97:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 97 done — avg loss: 3.374703


Epoch 98:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 98 done — avg loss: 3.372458


Epoch 99:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 99 done — avg loss: 3.371199


Epoch 100:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 100 done — avg loss: 3.370519


In [None]:
# Switch the model to evaluation mode before the encode/decode pass
# (presumably so encoder and decoder see identical, deterministic
# probabilities — confirm). The trailing semicolon keeps the notebook
# from echoing the call's return value.
model.eval();

In [None]:
# Encode the chunks of `dataset` with the model and the soft prompt trained
# for 100 epochs (helper defined elsewhere in the notebook). Each resulting
# entry is later read through its 'ids' and 'bits' keys.
chunk_data = encode_with_soft_prompt(
    dataset, model, soft_prompt)

Encoding chunks:   0%|          | 0/25 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/668 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/732 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/572 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/589 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/577 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/546 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/505 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/568 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/500 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/511 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/487 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/456 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/465 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/514 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/426 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/537 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/600 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/473 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/495 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/442 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/452 [00:00<?, ?it/s]

Encoding done in 194.38s


In [None]:
# Decode the per-chunk encodings, passing the same model and soft prompt
# that were given to the encoder. The result is checked against the
# original token ids further down.
all_rec = decode_with_soft_prompt(
    chunk_data, model, soft_prompt)

Decoding chunks:   0%|          | 0/25 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/668 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/732 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/572 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/589 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/577 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/546 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/505 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/453 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/568 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/500 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/511 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/487 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/456 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/465 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/514 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/426 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/537 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/600 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/473 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/494 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/495 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/442 [00:00<?, ?it/s]

  Tokens in chunk:   0%|          | 0/452 [00:00<?, ?it/s]

Decoding done in 195.70s


In [None]:
# Flatten the per-chunk results into single flat lists. Nested comprehensions
# run in linear time, whereas sum(lists, []) re-copies the growing
# accumulator on every addition (quadratic in the total length).
orig_ids = [tok for entry in chunk_data for tok in entry['ids'].tolist()]
encoded_output = [bit for entry in chunk_data for bit in entry['bits']]

In [None]:
# Losslessness check: compare the original token ids with the decoded ones.
compare_sequences(orig_ids, all_rec)

✅  Совпадают полностью!


In [None]:
# Report original size, compressed size (in bits) and their ratio.
# NOTE(review): `encoded_output` holds only the concatenated per-chunk 'bits';
# the soft prompt that the decoder also takes as input is not part of this
# payload — confirm whether it should be counted in the ratio.
# NOTE(review): `data` is defined outside this section — presumably the raw
# input the dataset was built from; verify.
compression_stats(data, encoded_output)

Исходный размер:   400000 бит
Размер после сжатия: 62589 бит
Коэффициент сжатия: 0.1565
