In [4]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from baukit import TraceDict

# Download the modelz
device = "cuda:0"
model_name="EleutherAI/Pythia-70M-deduped"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# set seed
torch.manual_seed(0)
np.random.seed(0)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from autoencoders import *
# ae_model_id = ["jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.1", "jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.2.mlp"]
model_id = "jbrinkma/Pythia-70M-deduped-SAEs"
autoencoders = []
layers = model.config.num_hidden_layers
cache_names = [(f"gpt_neox.layers.{i}", f"gpt_neox.layers.{i+1}.mlp") for i in range(layers-1)]
# cache_names = [(f"gpt_neox.layers.{i}", f"gpt_neox.layers.{i+1}.attention") for i in range(layers-1)]
# cache_names = [(f"gpt_neox.layers.{i}", f"gpt_neox.layers.{i+1}.mlp") for i in range(1,2)]
num_layers = len(cache_names)
cache_names = [item for sublist in cache_names for item in sublist]
filenames = [(f"Pythia-70M-deduped-{i}.pt", f"Pythia-70M-deduped-mlp-{i+1}.pt") for i in range(layers-1)]
# filenames = [(f"Pythia-70M-deduped-{i}.pt", f"Pythia-70M-deduped-attention-{i+1}.pt") for i in range(layers-1)]
# filenames = [(f"Pythia-70M-deduped-{i}.pt", f"Pythia-70M-deduped-mlp-{i+1}.pt") for i in range(1,2)]
filenames = [item for sublist in filenames for item in sublist]
for filen in filenames:
    ae_download_location = hf_hub_download(repo_id=model_id, filename=filen)
    autoencoder = torch.load(ae_download_location)
    autoencoder.to_device(device)
    # Freeze autoencoder weights
    autoencoder.encoder.requires_grad_(False)
    autoencoder.encoder_bias.requires_grad_(False)
    autoencoders.append(autoencoder)

In [6]:
from activation_dataset import chunk_and_tokenize
# Download the dataset
# This formats it, so every datapoint is max_length tokens long
# The batch size is for loading activations from the LLM, not for inference on the autoencoder
dataset_name="stas/openwebtext-10k"
max_seq_length=256
dataset = load_dataset(dataset_name, split="train")
dataset, _ = chunk_and_tokenize(dataset, tokenizer, max_length=max_seq_length)
max_tokens = dataset.num_rows*max_seq_length
print(f"Number of tokens: {max_tokens/1e6:.2f}M")

Found cached dataset openwebtext-10k (/root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)
Loading cached processed dataset at /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b/cache-89f7d956ed9de0f6_*_of_00008.arrow


Number of tokens: 11.23M


In [7]:
# split dataset into both a train & test set
test_size = 0.1
test_size = int(len(dataset)*test_size)
train_size = len(dataset)-test_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
print(f"Number of Train Tokens: {len(train_dataset)*max_seq_length/1e6:.2f}M")
print(f"Number of Test Tokens: {len(test_dataset)*max_seq_length/1e6:.2f}M")

Number of Train Tokens: 10.10M
Number of Test Tokens: 1.12M


In [8]:
# Create dataloaders
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

In [9]:
from torch.optim import Adam
from torch import nn
    
class mlp(nn.Module):
    def __init__(self, input_size, output_size=None, hidden_size=None, bias=True):
        super().__init__()
        if(output_size is None):
            output_size = input_size
        if(hidden_size is None):
            hidden_size = input_size
        # Only adding a bias to the latent layer
        self.linear = nn.Linear(input_size, hidden_size, bias=bias)
        self.linear2 = nn.Linear(hidden_size, output_size, bias=False)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x
    def latent(self, x):
        x = self.linear(x)
        x = self.relu(x)
        return x
    
class linear(nn.Module):
    def __init__(self, input_size, output_size=None, bias=True):
        super().__init__()
        if(output_size is None):
            output_size = input_size
        self.linear = nn.Linear(input_size, output_size, bias=bias)
    def forward(self, x):
        x = self.linear(x)
        return x

In [10]:
# Load in the features indices
# Note: linear & MLP features are a subet of alive features, not total features
import pickle
with open("linear_features.pkl", "rb") as f:
    linear_features = pickle.load(f)
with open("mlp_features.pkl", "rb") as f:
    mlp_features = pickle.load(f)
with open("alive_features.pkl", "rb") as f:
    alive_features_ind = pickle.load(f)
normalization_per_layer = torch.load("normalization_per_layer.pt")

In [11]:
num_features, d_model = autoencoder.encoder.shape

mlp_weights = []
mlp_optimizers = []
for auto_ind in range(num_layers):
    mlp_weights.append(mlp(num_features, bias=False))
    mlp_weights[-1].to(device)
    mlp_optimizers.append(Adam(mlp_weights[-1].parameters(), lr=1e-3))

In [24]:
for i, batch in enumerate(tqdm(test_dataloader)):
    batch = batch["input_ids"].to(device)
    with torch.no_grad():
        with TraceDict(model, cache_names) as ret:
            original_logits = model(batch)
    break

  0%|          | 0/138 [00:01<?, ?it/s]


In [35]:
logits = original_logits.logits

# Assuming logits and batch are provided as follows
# logits.shape = (batch_size, seq, vocab_size)
# batch.shape = (batch_size, seq)

# Reshape logits and labels for CrossEntropyLoss
logits = logits.view(-1, logits.size(-1))  # Reshape to (batch_size * seq, vocab_size)
labels = batch.view(-1)  # Flatten to (batch_size * seq)

# Define the loss function
loss_fn = nn.CrossEntropyLoss()

# Calculate loss
loss = loss_fn(logits, labels)

# Print or return the loss
print(loss)

tensor(12.4122, device='cuda:0')


torch.Size([804864, 512])

In [30]:
batch.shape

torch.Size([32, 256])

In [15]:
# Find the arguments that model() takes in (ie def of model.forward())
arguments = model.forward.__code__.co_varnames
arguments

('self',
 'input_ids',
 'attention_mask',
 'position_ids',
 'inputs_embeds',
 'head_mask',
 'past_key_values',
 'labels',
 'use_cache',
 'output_attentions',
 'output_hidden_states',
 'return_dict',
 'outputs',
 'hidden_states',
 'lm_logits',
 'lm_loss',
 'shift_logits',
 'loss_fct',
 'output')

In [20]:
def x(a = 1, b = 2, c = 3):
    return a+b+c
x.__code__.co_varnames

('a', 'b', 'c')

In [41]:
from baukit import TraceDict, Trace
def replace_with_reconstructed_features(mlp_out, layer_name):
    # We will have the true activations at MLP_out, & the reconstruction. 
    return x_hat

sparse_weights = mlp_weights
ce = nn.CrossEntropyLoss()

original_loss = [0]*num_layers
reconstructed_loss = [0]*num_layers
loss_diff = [0]*num_layers
# Cross entropy check
for i, batch in enumerate(tqdm(test_dataloader)):
    batch = batch["input_ids"].to(device)
    with torch.no_grad():
        with TraceDict(model, cache_names) as ret:
            original_logits = model(batch).logits
        for cache_name_ind in range(len(cache_names) // 2):
            sparse_weight = sparse_weights[cache_name_ind]
            input_cache_name = cache_names[cache_name_ind * 2]
            input_autoencoder = autoencoders[cache_name_ind * 2]
            input_activations = ret[input_cache_name].output
            if isinstance(input_activations, tuple):
                input_activations = input_activations[0]
            input_internal_activations = rearrange(input_activations, "b s n -> (b s) n")
            input_internal_activations = input_autoencoder.encode(input_internal_activations)
            x_hat = sparse_weights[cache_name_ind](input_internal_activations)
            # Decode x_hat
            output_autoencoder = autoencoders[cache_name_ind * 2 + 1]
            x_hat = output_autoencoder.decode(x_hat)
            # Rearrange x_hat back to b s n
            x_hat = rearrange(x_hat, "(b s) n -> b s n", b=batch.shape[0], s=batch.shape[1])
            # replace the activations with the reconstructed activations
            output_cache_name = cache_names[cache_name_ind * 2 + 1]
            with Trace(model, output_cache_name, edit_output=replace_with_reconstructed_features) as ret_2:
                logits = model(batch).logits
            logits = logits.view(-1, logits.size(-1))  # Reshape to (batch_size * seq, vocab_size)
            original_logits = original_logits.view(-1, original_logits.size(-1))  # Reshape to (batch_size * seq, vocab_size)
            labels = batch.view(-1)  # Flatten to (batch_size * seq)

            # Define the loss function
            ce = nn.CrossEntropyLoss()
            # Calculate loss
            recon_loss = ce(logits, labels)
            orig_loss = ce(original_logits, labels)
            loss_diff[cache_name_ind] += recon_loss - orig_loss
            reconstructed_loss[cache_name_ind] += recon_loss
            original_loss[cache_name_ind] += orig_loss
            divide_by = i + 1
            print(f"Layer:{cache_name_ind+1} | Original Loss: {original_loss[cache_name_ind]/divide_by:.2f} | Reconstructed Loss: {reconstructed_loss[cache_name_ind]/divide_by:.2f} | Difference: {loss_diff[cache_name_ind]/divide_by:.2f}")
# Average
reconstructed_loss /= len(dataset)
original_loss /= len(dataset)
loss_diff /= len(dataset)
print(f"Original Loss: {original_loss:.2f}")
print(f"Reconstructed Loss: {reconstructed_loss:.2f}")
# Find percentage of loss is reconstructed
print(f"Percentage of loss reconstructed: {100*(original_loss-reconstructed_loss)/original_loss:.2f}%")

  0%|          | 0/138 [00:00<?, ?it/s]

Layer:1 | Original Loss: 10.93 | Reconstructed Loss: 9.30 | Difference: -1.63
Layer:2 | Original Loss: 10.93 | Reconstructed Loss: 9.73 | Difference: -1.20
Layer:3 | Original Loss: 10.93 | Reconstructed Loss: 10.80 | Difference: -0.13
Layer:4 | Original Loss: 10.93 | Reconstructed Loss: 9.44 | Difference: -1.48


  1%|          | 1/138 [00:03<07:33,  3.31s/it]

Layer:5 | Original Loss: 10.93 | Reconstructed Loss: 16.37 | Difference: 5.44
Layer:1 | Original Loss: 10.86 | Reconstructed Loss: 9.34 | Difference: -1.52
Layer:2 | Original Loss: 10.86 | Reconstructed Loss: 9.69 | Difference: -1.17
Layer:3 | Original Loss: 10.86 | Reconstructed Loss: 10.68 | Difference: -0.18
Layer:4 | Original Loss: 10.86 | Reconstructed Loss: 9.44 | Difference: -1.42


  1%|▏         | 2/138 [00:05<05:25,  2.39s/it]

Layer:5 | Original Loss: 10.86 | Reconstructed Loss: 16.28 | Difference: 5.42
Layer:1 | Original Loss: 10.94 | Reconstructed Loss: 9.41 | Difference: -1.53
Layer:2 | Original Loss: 10.94 | Reconstructed Loss: 9.75 | Difference: -1.19
Layer:3 | Original Loss: 10.94 | Reconstructed Loss: 10.75 | Difference: -0.19
Layer:4 | Original Loss: 10.94 | Reconstructed Loss: 9.55 | Difference: -1.39


  2%|▏         | 3/138 [00:06<04:44,  2.10s/it]

Layer:5 | Original Loss: 10.94 | Reconstructed Loss: 16.34 | Difference: 5.40
Layer:1 | Original Loss: 10.90 | Reconstructed Loss: 9.35 | Difference: -1.54
Layer:2 | Original Loss: 10.90 | Reconstructed Loss: 9.71 | Difference: -1.18
Layer:3 | Original Loss: 10.90 | Reconstructed Loss: 10.71 | Difference: -0.18
Layer:4 | Original Loss: 10.90 | Reconstructed Loss: 9.52 | Difference: -1.37


  3%|▎         | 4/138 [00:08<04:23,  1.97s/it]

Layer:5 | Original Loss: 10.90 | Reconstructed Loss: 16.33 | Difference: 5.43
Layer:1 | Original Loss: 10.94 | Reconstructed Loss: 9.38 | Difference: -1.56
Layer:2 | Original Loss: 10.94 | Reconstructed Loss: 9.76 | Difference: -1.18
Layer:3 | Original Loss: 10.94 | Reconstructed Loss: 10.75 | Difference: -0.20
Layer:4 | Original Loss: 10.94 | Reconstructed Loss: 9.57 | Difference: -1.37


  4%|▎         | 5/138 [00:10<04:11,  1.89s/it]

Layer:5 | Original Loss: 10.94 | Reconstructed Loss: 16.38 | Difference: 5.43
Layer:1 | Original Loss: 10.92 | Reconstructed Loss: 9.38 | Difference: -1.54
Layer:2 | Original Loss: 10.92 | Reconstructed Loss: 9.73 | Difference: -1.19
Layer:3 | Original Loss: 10.92 | Reconstructed Loss: 10.72 | Difference: -0.20
Layer:4 | Original Loss: 10.92 | Reconstructed Loss: 9.55 | Difference: -1.37


  4%|▍         | 6/138 [00:12<04:03,  1.85s/it]

Layer:5 | Original Loss: 10.92 | Reconstructed Loss: 16.37 | Difference: 5.45
Layer:1 | Original Loss: 10.92 | Reconstructed Loss: 9.41 | Difference: -1.51
Layer:2 | Original Loss: 10.92 | Reconstructed Loss: 9.73 | Difference: -1.19
Layer:3 | Original Loss: 10.92 | Reconstructed Loss: 10.72 | Difference: -0.20
Layer:4 | Original Loss: 10.92 | Reconstructed Loss: 9.54 | Difference: -1.38


  5%|▌         | 7/138 [00:13<03:58,  1.82s/it]

Layer:5 | Original Loss: 10.92 | Reconstructed Loss: 16.39 | Difference: 5.47
Layer:1 | Original Loss: 10.92 | Reconstructed Loss: 9.40 | Difference: -1.52
Layer:2 | Original Loss: 10.92 | Reconstructed Loss: 9.72 | Difference: -1.20
Layer:3 | Original Loss: 10.92 | Reconstructed Loss: 10.71 | Difference: -0.21
Layer:4 | Original Loss: 10.92 | Reconstructed Loss: 9.54 | Difference: -1.38


  6%|▌         | 8/138 [00:15<03:53,  1.80s/it]

Layer:5 | Original Loss: 10.92 | Reconstructed Loss: 16.42 | Difference: 5.50
Layer:1 | Original Loss: 10.87 | Reconstructed Loss: 9.35 | Difference: -1.51
Layer:2 | Original Loss: 10.87 | Reconstructed Loss: 9.68 | Difference: -1.18
Layer:3 | Original Loss: 10.87 | Reconstructed Loss: 10.65 | Difference: -0.22
Layer:4 | Original Loss: 10.87 | Reconstructed Loss: 9.49 | Difference: -1.38


  7%|▋         | 9/138 [00:17<03:50,  1.78s/it]

Layer:5 | Original Loss: 10.87 | Reconstructed Loss: 16.39 | Difference: 5.52
Layer:1 | Original Loss: 10.91 | Reconstructed Loss: 9.38 | Difference: -1.53
Layer:2 | Original Loss: 10.91 | Reconstructed Loss: 9.72 | Difference: -1.19
Layer:3 | Original Loss: 10.91 | Reconstructed Loss: 10.69 | Difference: -0.22
Layer:4 | Original Loss: 10.91 | Reconstructed Loss: 9.53 | Difference: -1.38


  7%|▋         | 10/138 [00:19<03:47,  1.78s/it]

Layer:5 | Original Loss: 10.91 | Reconstructed Loss: 16.43 | Difference: 5.52
Layer:1 | Original Loss: 10.91 | Reconstructed Loss: 9.38 | Difference: -1.53
Layer:2 | Original Loss: 10.91 | Reconstructed Loss: 9.72 | Difference: -1.20
Layer:3 | Original Loss: 10.91 | Reconstructed Loss: 10.70 | Difference: -0.21
Layer:4 | Original Loss: 10.91 | Reconstructed Loss: 9.54 | Difference: -1.37


  8%|▊         | 11/138 [00:20<03:44,  1.77s/it]

Layer:5 | Original Loss: 10.91 | Reconstructed Loss: 16.42 | Difference: 5.51
Layer:1 | Original Loss: 11.09 | Reconstructed Loss: 9.55 | Difference: -1.54
Layer:2 | Original Loss: 11.09 | Reconstructed Loss: 9.89 | Difference: -1.20
Layer:3 | Original Loss: 11.09 | Reconstructed Loss: 10.88 | Difference: -0.22
Layer:4 | Original Loss: 11.09 | Reconstructed Loss: 9.71 | Difference: -1.38


  9%|▊         | 12/138 [00:22<03:42,  1.77s/it]

Layer:5 | Original Loss: 11.09 | Reconstructed Loss: 16.59 | Difference: 5.50
Layer:1 | Original Loss: 11.10 | Reconstructed Loss: 9.55 | Difference: -1.55
Layer:2 | Original Loss: 11.10 | Reconstructed Loss: 9.89 | Difference: -1.21
Layer:3 | Original Loss: 11.10 | Reconstructed Loss: 10.88 | Difference: -0.22
Layer:4 | Original Loss: 11.10 | Reconstructed Loss: 9.71 | Difference: -1.39


  9%|▉         | 13/138 [00:24<03:41,  1.77s/it]

Layer:5 | Original Loss: 11.10 | Reconstructed Loss: 16.59 | Difference: 5.49
Layer:1 | Original Loss: 11.16 | Reconstructed Loss: 9.62 | Difference: -1.55
Layer:2 | Original Loss: 11.16 | Reconstructed Loss: 9.95 | Difference: -1.21
Layer:3 | Original Loss: 11.16 | Reconstructed Loss: 10.93 | Difference: -0.23
Layer:4 | Original Loss: 11.16 | Reconstructed Loss: 9.77 | Difference: -1.39


 10%|█         | 14/138 [00:26<03:39,  1.77s/it]

Layer:5 | Original Loss: 11.16 | Reconstructed Loss: 16.63 | Difference: 5.47
Layer:1 | Original Loss: 11.36 | Reconstructed Loss: 9.79 | Difference: -1.57
Layer:2 | Original Loss: 11.36 | Reconstructed Loss: 10.14 | Difference: -1.22
Layer:3 | Original Loss: 11.36 | Reconstructed Loss: 11.11 | Difference: -0.25
Layer:4 | Original Loss: 11.36 | Reconstructed Loss: 9.94 | Difference: -1.42


 11%|█         | 15/138 [00:27<03:37,  1.77s/it]

Layer:5 | Original Loss: 11.36 | Reconstructed Loss: 16.80 | Difference: 5.44
Layer:1 | Original Loss: 11.34 | Reconstructed Loss: 9.76 | Difference: -1.57
Layer:2 | Original Loss: 11.34 | Reconstructed Loss: 10.11 | Difference: -1.22
Layer:3 | Original Loss: 11.34 | Reconstructed Loss: 11.09 | Difference: -0.25
Layer:4 | Original Loss: 11.34 | Reconstructed Loss: 9.93 | Difference: -1.41


 12%|█▏        | 16/138 [00:29<03:35,  1.77s/it]

Layer:5 | Original Loss: 11.34 | Reconstructed Loss: 16.76 | Difference: 5.42
Layer:1 | Original Loss: 11.36 | Reconstructed Loss: 9.78 | Difference: -1.57
Layer:2 | Original Loss: 11.36 | Reconstructed Loss: 10.13 | Difference: -1.22
Layer:3 | Original Loss: 11.36 | Reconstructed Loss: 11.11 | Difference: -0.25
Layer:4 | Original Loss: 11.36 | Reconstructed Loss: 9.94 | Difference: -1.42


 12%|█▏        | 17/138 [00:31<03:33,  1.77s/it]

Layer:5 | Original Loss: 11.36 | Reconstructed Loss: 16.79 | Difference: 5.43
Layer:1 | Original Loss: 11.32 | Reconstructed Loss: 9.75 | Difference: -1.57
Layer:2 | Original Loss: 11.32 | Reconstructed Loss: 10.10 | Difference: -1.22
Layer:3 | Original Loss: 11.32 | Reconstructed Loss: 11.08 | Difference: -0.25
Layer:4 | Original Loss: 11.32 | Reconstructed Loss: 9.91 | Difference: -1.41


 13%|█▎        | 18/138 [00:33<03:33,  1.78s/it]

Layer:5 | Original Loss: 11.32 | Reconstructed Loss: 16.76 | Difference: 5.44
Layer:1 | Original Loss: 11.26 | Reconstructed Loss: 9.69 | Difference: -1.57
Layer:2 | Original Loss: 11.26 | Reconstructed Loss: 10.04 | Difference: -1.22
Layer:3 | Original Loss: 11.26 | Reconstructed Loss: 11.01 | Difference: -0.24
Layer:4 | Original Loss: 11.26 | Reconstructed Loss: 9.85 | Difference: -1.41


 14%|█▍        | 19/138 [00:35<03:31,  1.77s/it]

Layer:5 | Original Loss: 11.26 | Reconstructed Loss: 16.70 | Difference: 5.44
Layer:1 | Original Loss: 11.29 | Reconstructed Loss: 9.73 | Difference: -1.56
Layer:2 | Original Loss: 11.29 | Reconstructed Loss: 10.07 | Difference: -1.22
Layer:3 | Original Loss: 11.29 | Reconstructed Loss: 11.05 | Difference: -0.25
Layer:4 | Original Loss: 11.29 | Reconstructed Loss: 9.88 | Difference: -1.41


 14%|█▍        | 20/138 [00:36<03:29,  1.77s/it]

Layer:5 | Original Loss: 11.29 | Reconstructed Loss: 16.73 | Difference: 5.43
Layer:1 | Original Loss: 11.32 | Reconstructed Loss: 9.75 | Difference: -1.57
Layer:2 | Original Loss: 11.32 | Reconstructed Loss: 10.09 | Difference: -1.23
Layer:3 | Original Loss: 11.32 | Reconstructed Loss: 11.07 | Difference: -0.25
Layer:4 | Original Loss: 11.32 | Reconstructed Loss: 9.91 | Difference: -1.41


 15%|█▌        | 21/138 [00:38<03:27,  1.77s/it]

Layer:5 | Original Loss: 11.32 | Reconstructed Loss: 16.75 | Difference: 5.43
Layer:1 | Original Loss: 11.38 | Reconstructed Loss: 9.81 | Difference: -1.57
Layer:2 | Original Loss: 11.38 | Reconstructed Loss: 10.15 | Difference: -1.23
Layer:3 | Original Loss: 11.38 | Reconstructed Loss: 11.14 | Difference: -0.25
Layer:4 | Original Loss: 11.38 | Reconstructed Loss: 9.96 | Difference: -1.42


 16%|█▌        | 22/138 [00:40<03:25,  1.77s/it]

Layer:5 | Original Loss: 11.38 | Reconstructed Loss: 16.82 | Difference: 5.43
Layer:1 | Original Loss: 11.98 | Reconstructed Loss: 10.38 | Difference: -1.60
Layer:2 | Original Loss: 11.98 | Reconstructed Loss: 10.70 | Difference: -1.28
Layer:3 | Original Loss: 11.98 | Reconstructed Loss: 11.77 | Difference: -0.21
Layer:4 | Original Loss: 11.98 | Reconstructed Loss: 10.51 | Difference: -1.47


 17%|█▋        | 23/138 [00:42<03:24,  1.77s/it]

Layer:5 | Original Loss: 11.98 | Reconstructed Loss: 17.54 | Difference: 5.56
Layer:1 | Original Loss: 11.91 | Reconstructed Loss: 10.30 | Difference: -1.60
Layer:2 | Original Loss: 11.91 | Reconstructed Loss: 10.63 | Difference: -1.28
Layer:3 | Original Loss: 11.91 | Reconstructed Loss: 11.70 | Difference: -0.21
Layer:4 | Original Loss: 11.91 | Reconstructed Loss: 10.44 | Difference: -1.46


 17%|█▋        | 24/138 [00:43<03:22,  1.77s/it]

Layer:5 | Original Loss: 11.91 | Reconstructed Loss: 17.47 | Difference: 5.56
Layer:1 | Original Loss: 11.84 | Reconstructed Loss: 10.25 | Difference: -1.60
Layer:2 | Original Loss: 11.84 | Reconstructed Loss: 10.57 | Difference: -1.28
Layer:3 | Original Loss: 11.84 | Reconstructed Loss: 11.63 | Difference: -0.21
Layer:4 | Original Loss: 11.84 | Reconstructed Loss: 10.39 | Difference: -1.46


 18%|█▊        | 25/138 [00:45<03:20,  1.78s/it]

Layer:5 | Original Loss: 11.84 | Reconstructed Loss: 17.40 | Difference: 5.56
Layer:1 | Original Loss: 11.82 | Reconstructed Loss: 10.23 | Difference: -1.59
Layer:2 | Original Loss: 11.82 | Reconstructed Loss: 10.54 | Difference: -1.27
Layer:3 | Original Loss: 11.82 | Reconstructed Loss: 11.60 | Difference: -0.22
Layer:4 | Original Loss: 11.82 | Reconstructed Loss: 10.36 | Difference: -1.45


 19%|█▉        | 26/138 [00:47<03:18,  1.78s/it]

Layer:5 | Original Loss: 11.82 | Reconstructed Loss: 17.39 | Difference: 5.57
Layer:1 | Original Loss: 11.79 | Reconstructed Loss: 10.20 | Difference: -1.59
Layer:2 | Original Loss: 11.79 | Reconstructed Loss: 10.52 | Difference: -1.27
Layer:3 | Original Loss: 11.79 | Reconstructed Loss: 11.57 | Difference: -0.22
Layer:4 | Original Loss: 11.79 | Reconstructed Loss: 10.34 | Difference: -1.45


 20%|█▉        | 27/138 [00:49<03:17,  1.78s/it]

Layer:5 | Original Loss: 11.79 | Reconstructed Loss: 17.35 | Difference: 5.56
Layer:1 | Original Loss: 11.86 | Reconstructed Loss: 10.27 | Difference: -1.59
Layer:2 | Original Loss: 11.86 | Reconstructed Loss: 10.59 | Difference: -1.27
Layer:3 | Original Loss: 11.86 | Reconstructed Loss: 11.65 | Difference: -0.21
Layer:4 | Original Loss: 11.86 | Reconstructed Loss: 10.41 | Difference: -1.46


 20%|██        | 28/138 [00:51<03:15,  1.78s/it]

Layer:5 | Original Loss: 11.86 | Reconstructed Loss: 17.41 | Difference: 5.55
Layer:1 | Original Loss: 11.85 | Reconstructed Loss: 10.26 | Difference: -1.59
Layer:2 | Original Loss: 11.85 | Reconstructed Loss: 10.58 | Difference: -1.27
Layer:3 | Original Loss: 11.85 | Reconstructed Loss: 11.64 | Difference: -0.21
Layer:4 | Original Loss: 11.85 | Reconstructed Loss: 10.40 | Difference: -1.45


 21%|██        | 29/138 [00:52<03:13,  1.78s/it]

Layer:5 | Original Loss: 11.85 | Reconstructed Loss: 17.40 | Difference: 5.55
Layer:1 | Original Loss: 11.82 | Reconstructed Loss: 10.22 | Difference: -1.59
Layer:2 | Original Loss: 11.82 | Reconstructed Loss: 10.55 | Difference: -1.27
Layer:3 | Original Loss: 11.82 | Reconstructed Loss: 11.61 | Difference: -0.21
Layer:4 | Original Loss: 11.82 | Reconstructed Loss: 10.37 | Difference: -1.45


 22%|██▏       | 30/138 [00:54<03:12,  1.78s/it]

Layer:5 | Original Loss: 11.82 | Reconstructed Loss: 17.37 | Difference: 5.55
Layer:1 | Original Loss: 11.79 | Reconstructed Loss: 10.20 | Difference: -1.59
Layer:2 | Original Loss: 11.79 | Reconstructed Loss: 10.53 | Difference: -1.27
Layer:3 | Original Loss: 11.79 | Reconstructed Loss: 11.59 | Difference: -0.21
Layer:4 | Original Loss: 11.79 | Reconstructed Loss: 10.35 | Difference: -1.45


 22%|██▏       | 31/138 [00:56<03:10,  1.78s/it]

Layer:5 | Original Loss: 11.79 | Reconstructed Loss: 17.33 | Difference: 5.54
Layer:1 | Original Loss: 11.76 | Reconstructed Loss: 10.17 | Difference: -1.59
Layer:2 | Original Loss: 11.76 | Reconstructed Loss: 10.50 | Difference: -1.26
Layer:3 | Original Loss: 11.76 | Reconstructed Loss: 11.55 | Difference: -0.21
Layer:4 | Original Loss: 11.76 | Reconstructed Loss: 10.32 | Difference: -1.44


 23%|██▎       | 32/138 [00:58<03:08,  1.78s/it]

Layer:5 | Original Loss: 11.76 | Reconstructed Loss: 17.29 | Difference: 5.53
Layer:1 | Original Loss: 11.74 | Reconstructed Loss: 10.15 | Difference: -1.58
Layer:2 | Original Loss: 11.74 | Reconstructed Loss: 10.48 | Difference: -1.26
Layer:3 | Original Loss: 11.74 | Reconstructed Loss: 11.53 | Difference: -0.21
Layer:4 | Original Loss: 11.74 | Reconstructed Loss: 10.30 | Difference: -1.44


 24%|██▍       | 33/138 [00:59<03:07,  1.78s/it]

Layer:5 | Original Loss: 11.74 | Reconstructed Loss: 17.27 | Difference: 5.53
Layer:1 | Original Loss: 11.72 | Reconstructed Loss: 10.14 | Difference: -1.58
Layer:2 | Original Loss: 11.72 | Reconstructed Loss: 10.46 | Difference: -1.26
Layer:3 | Original Loss: 11.72 | Reconstructed Loss: 11.51 | Difference: -0.21
Layer:4 | Original Loss: 11.72 | Reconstructed Loss: 10.28 | Difference: -1.44


 25%|██▍       | 34/138 [01:01<03:05,  1.78s/it]

Layer:5 | Original Loss: 11.72 | Reconstructed Loss: 17.25 | Difference: 5.53
Layer:1 | Original Loss: 11.70 | Reconstructed Loss: 10.12 | Difference: -1.58
Layer:2 | Original Loss: 11.70 | Reconstructed Loss: 10.44 | Difference: -1.26
Layer:3 | Original Loss: 11.70 | Reconstructed Loss: 11.50 | Difference: -0.21
Layer:4 | Original Loss: 11.70 | Reconstructed Loss: 10.26 | Difference: -1.44


 25%|██▌       | 35/138 [01:03<03:03,  1.78s/it]

Layer:5 | Original Loss: 11.70 | Reconstructed Loss: 17.23 | Difference: 5.53
Layer:1 | Original Loss: 11.67 | Reconstructed Loss: 10.09 | Difference: -1.58
Layer:2 | Original Loss: 11.67 | Reconstructed Loss: 10.41 | Difference: -1.26
Layer:3 | Original Loss: 11.67 | Reconstructed Loss: 11.47 | Difference: -0.21
Layer:4 | Original Loss: 11.67 | Reconstructed Loss: 10.23 | Difference: -1.44


 26%|██▌       | 36/138 [01:05<03:02,  1.79s/it]

Layer:5 | Original Loss: 11.67 | Reconstructed Loss: 17.21 | Difference: 5.53
Layer:1 | Original Loss: 11.66 | Reconstructed Loss: 10.08 | Difference: -1.58
Layer:2 | Original Loss: 11.66 | Reconstructed Loss: 10.40 | Difference: -1.26
Layer:3 | Original Loss: 11.66 | Reconstructed Loss: 11.45 | Difference: -0.21
Layer:4 | Original Loss: 11.66 | Reconstructed Loss: 10.22 | Difference: -1.44


 27%|██▋       | 37/138 [01:07<03:00,  1.79s/it]

Layer:5 | Original Loss: 11.66 | Reconstructed Loss: 17.19 | Difference: 5.53
Layer:1 | Original Loss: 11.64 | Reconstructed Loss: 10.07 | Difference: -1.58
Layer:2 | Original Loss: 11.64 | Reconstructed Loss: 10.39 | Difference: -1.25
Layer:3 | Original Loss: 11.64 | Reconstructed Loss: 11.43 | Difference: -0.21
Layer:4 | Original Loss: 11.64 | Reconstructed Loss: 10.21 | Difference: -1.44


 28%|██▊       | 38/138 [01:08<02:58,  1.79s/it]

Layer:5 | Original Loss: 11.64 | Reconstructed Loss: 17.17 | Difference: 5.53
Layer:1 | Original Loss: 11.63 | Reconstructed Loss: 10.05 | Difference: -1.57
Layer:2 | Original Loss: 11.63 | Reconstructed Loss: 10.38 | Difference: -1.25
Layer:3 | Original Loss: 11.63 | Reconstructed Loss: 11.42 | Difference: -0.21
Layer:4 | Original Loss: 11.63 | Reconstructed Loss: 10.19 | Difference: -1.43


 28%|██▊       | 39/138 [01:10<02:57,  1.79s/it]

Layer:5 | Original Loss: 11.63 | Reconstructed Loss: 17.16 | Difference: 5.53
Layer:1 | Original Loss: 11.62 | Reconstructed Loss: 10.04 | Difference: -1.58
Layer:2 | Original Loss: 11.62 | Reconstructed Loss: 10.37 | Difference: -1.25
Layer:3 | Original Loss: 11.62 | Reconstructed Loss: 11.41 | Difference: -0.21
Layer:4 | Original Loss: 11.62 | Reconstructed Loss: 10.19 | Difference: -1.43


 29%|██▉       | 40/138 [01:12<02:55,  1.79s/it]

Layer:5 | Original Loss: 11.62 | Reconstructed Loss: 17.14 | Difference: 5.53
Layer:1 | Original Loss: 11.61 | Reconstructed Loss: 10.03 | Difference: -1.58
Layer:2 | Original Loss: 11.61 | Reconstructed Loss: 10.36 | Difference: -1.25
Layer:3 | Original Loss: 11.61 | Reconstructed Loss: 11.40 | Difference: -0.21
Layer:4 | Original Loss: 11.61 | Reconstructed Loss: 10.18 | Difference: -1.43


 30%|██▉       | 41/138 [01:14<02:53,  1.79s/it]

Layer:5 | Original Loss: 11.61 | Reconstructed Loss: 17.13 | Difference: 5.53
Layer:1 | Original Loss: 11.61 | Reconstructed Loss: 10.03 | Difference: -1.58
Layer:2 | Original Loss: 11.61 | Reconstructed Loss: 10.36 | Difference: -1.25
Layer:3 | Original Loss: 11.61 | Reconstructed Loss: 11.40 | Difference: -0.21
Layer:4 | Original Loss: 11.61 | Reconstructed Loss: 10.18 | Difference: -1.43


 30%|███       | 42/138 [01:16<02:52,  1.79s/it]

Layer:5 | Original Loss: 11.61 | Reconstructed Loss: 17.13 | Difference: 5.52
Layer:1 | Original Loss: 11.60 | Reconstructed Loss: 10.02 | Difference: -1.58
Layer:2 | Original Loss: 11.60 | Reconstructed Loss: 10.35 | Difference: -1.25
Layer:3 | Original Loss: 11.60 | Reconstructed Loss: 11.39 | Difference: -0.21
Layer:4 | Original Loss: 11.60 | Reconstructed Loss: 10.17 | Difference: -1.43


 31%|███       | 43/138 [01:17<02:50,  1.80s/it]

Layer:5 | Original Loss: 11.60 | Reconstructed Loss: 17.12 | Difference: 5.52
Layer:1 | Original Loss: 11.59 | Reconstructed Loss: 10.01 | Difference: -1.58
Layer:2 | Original Loss: 11.59 | Reconstructed Loss: 10.34 | Difference: -1.25
Layer:3 | Original Loss: 11.59 | Reconstructed Loss: 11.38 | Difference: -0.21
Layer:4 | Original Loss: 11.59 | Reconstructed Loss: 10.16 | Difference: -1.43


 32%|███▏      | 44/138 [01:19<02:49,  1.80s/it]

Layer:5 | Original Loss: 11.59 | Reconstructed Loss: 17.11 | Difference: 5.52
Layer:1 | Original Loss: 11.59 | Reconstructed Loss: 10.01 | Difference: -1.58
Layer:2 | Original Loss: 11.59 | Reconstructed Loss: 10.34 | Difference: -1.25
Layer:3 | Original Loss: 11.59 | Reconstructed Loss: 11.38 | Difference: -0.21
Layer:4 | Original Loss: 11.59 | Reconstructed Loss: 10.16 | Difference: -1.43


 33%|███▎      | 45/138 [01:21<02:47,  1.80s/it]

Layer:5 | Original Loss: 11.59 | Reconstructed Loss: 17.12 | Difference: 5.52
Layer:1 | Original Loss: 11.58 | Reconstructed Loss: 10.00 | Difference: -1.58
Layer:2 | Original Loss: 11.58 | Reconstructed Loss: 10.33 | Difference: -1.25
Layer:3 | Original Loss: 11.58 | Reconstructed Loss: 11.37 | Difference: -0.21
Layer:4 | Original Loss: 11.58 | Reconstructed Loss: 10.15 | Difference: -1.43


 33%|███▎      | 46/138 [01:23<02:45,  1.80s/it]

Layer:5 | Original Loss: 11.58 | Reconstructed Loss: 17.11 | Difference: 5.52
Layer:1 | Original Loss: 11.57 | Reconstructed Loss: 9.98 | Difference: -1.58
Layer:2 | Original Loss: 11.57 | Reconstructed Loss: 10.32 | Difference: -1.25
Layer:3 | Original Loss: 11.57 | Reconstructed Loss: 11.35 | Difference: -0.21
Layer:4 | Original Loss: 11.57 | Reconstructed Loss: 10.14 | Difference: -1.43


 34%|███▍      | 47/138 [01:25<02:43,  1.80s/it]

Layer:5 | Original Loss: 11.57 | Reconstructed Loss: 17.08 | Difference: 5.52
Layer:1 | Original Loss: 11.57 | Reconstructed Loss: 9.99 | Difference: -1.58
Layer:2 | Original Loss: 11.57 | Reconstructed Loss: 10.32 | Difference: -1.25
Layer:3 | Original Loss: 11.57 | Reconstructed Loss: 11.36 | Difference: -0.21
Layer:4 | Original Loss: 11.57 | Reconstructed Loss: 10.14 | Difference: -1.43


 35%|███▍      | 48/138 [01:26<02:42,  1.80s/it]

Layer:5 | Original Loss: 11.57 | Reconstructed Loss: 17.08 | Difference: 5.52
Layer:1 | Original Loss: 11.55 | Reconstructed Loss: 9.98 | Difference: -1.58
Layer:2 | Original Loss: 11.55 | Reconstructed Loss: 10.31 | Difference: -1.25
Layer:3 | Original Loss: 11.55 | Reconstructed Loss: 11.34 | Difference: -0.21
Layer:4 | Original Loss: 11.55 | Reconstructed Loss: 10.13 | Difference: -1.43


 36%|███▌      | 49/138 [01:28<02:40,  1.80s/it]

Layer:5 | Original Loss: 11.55 | Reconstructed Loss: 17.07 | Difference: 5.52
Layer:1 | Original Loss: 11.52 | Reconstructed Loss: 9.95 | Difference: -1.57
Layer:2 | Original Loss: 11.52 | Reconstructed Loss: 10.27 | Difference: -1.25
Layer:3 | Original Loss: 11.52 | Reconstructed Loss: 11.31 | Difference: -0.21
Layer:4 | Original Loss: 11.52 | Reconstructed Loss: 10.10 | Difference: -1.42


 36%|███▌      | 50/138 [01:30<02:38,  1.80s/it]

Layer:5 | Original Loss: 11.52 | Reconstructed Loss: 17.03 | Difference: 5.51
Layer:1 | Original Loss: 11.53 | Reconstructed Loss: 9.96 | Difference: -1.57
Layer:2 | Original Loss: 11.53 | Reconstructed Loss: 10.29 | Difference: -1.25
Layer:3 | Original Loss: 11.53 | Reconstructed Loss: 11.33 | Difference: -0.21
Layer:4 | Original Loss: 11.53 | Reconstructed Loss: 10.11 | Difference: -1.42


 37%|███▋      | 51/138 [01:32<02:36,  1.80s/it]

Layer:5 | Original Loss: 11.53 | Reconstructed Loss: 17.05 | Difference: 5.52
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.94 | Difference: -1.57
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.27 | Difference: -1.25
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.31 | Difference: -0.21
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.09 | Difference: -1.42


 38%|███▊      | 52/138 [01:34<02:35,  1.80s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 17.03 | Difference: 5.51
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.94 | Difference: -1.57
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.26 | Difference: -1.24
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.30 | Difference: -0.21
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.09 | Difference: -1.42


 38%|███▊      | 53/138 [01:35<02:33,  1.80s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 17.02 | Difference: 5.51
Layer:1 | Original Loss: 11.49 | Reconstructed Loss: 9.92 | Difference: -1.57
Layer:2 | Original Loss: 11.49 | Reconstructed Loss: 10.24 | Difference: -1.24
Layer:3 | Original Loss: 11.49 | Reconstructed Loss: 11.28 | Difference: -0.21
Layer:4 | Original Loss: 11.49 | Reconstructed Loss: 10.07 | Difference: -1.42


 39%|███▉      | 54/138 [01:37<02:31,  1.81s/it]

Layer:5 | Original Loss: 11.49 | Reconstructed Loss: 16.99 | Difference: 5.51
Layer:1 | Original Loss: 11.49 | Reconstructed Loss: 9.92 | Difference: -1.57
Layer:2 | Original Loss: 11.49 | Reconstructed Loss: 10.25 | Difference: -1.24
Layer:3 | Original Loss: 11.49 | Reconstructed Loss: 11.28 | Difference: -0.21
Layer:4 | Original Loss: 11.49 | Reconstructed Loss: 10.07 | Difference: -1.41


 40%|███▉      | 55/138 [01:39<02:29,  1.81s/it]

Layer:5 | Original Loss: 11.49 | Reconstructed Loss: 16.99 | Difference: 5.50
Layer:1 | Original Loss: 11.49 | Reconstructed Loss: 9.92 | Difference: -1.57
Layer:2 | Original Loss: 11.49 | Reconstructed Loss: 10.25 | Difference: -1.24
Layer:3 | Original Loss: 11.49 | Reconstructed Loss: 11.28 | Difference: -0.21
Layer:4 | Original Loss: 11.49 | Reconstructed Loss: 10.07 | Difference: -1.41


 41%|████      | 56/138 [01:41<02:28,  1.81s/it]

Layer:5 | Original Loss: 11.49 | Reconstructed Loss: 16.99 | Difference: 5.50
Layer:1 | Original Loss: 11.48 | Reconstructed Loss: 9.92 | Difference: -1.56
Layer:2 | Original Loss: 11.48 | Reconstructed Loss: 10.24 | Difference: -1.24
Layer:3 | Original Loss: 11.48 | Reconstructed Loss: 11.27 | Difference: -0.21
Layer:4 | Original Loss: 11.48 | Reconstructed Loss: 10.07 | Difference: -1.41


 41%|████▏     | 57/138 [01:43<02:26,  1.81s/it]

Layer:5 | Original Loss: 11.48 | Reconstructed Loss: 16.98 | Difference: 5.50
Layer:1 | Original Loss: 11.49 | Reconstructed Loss: 9.93 | Difference: -1.56
Layer:2 | Original Loss: 11.49 | Reconstructed Loss: 10.25 | Difference: -1.24
Layer:3 | Original Loss: 11.49 | Reconstructed Loss: 11.28 | Difference: -0.21
Layer:4 | Original Loss: 11.49 | Reconstructed Loss: 10.08 | Difference: -1.41


 42%|████▏     | 58/138 [01:44<02:24,  1.81s/it]

Layer:5 | Original Loss: 11.49 | Reconstructed Loss: 16.99 | Difference: 5.50
Layer:1 | Original Loss: 11.49 | Reconstructed Loss: 9.93 | Difference: -1.56
Layer:2 | Original Loss: 11.49 | Reconstructed Loss: 10.25 | Difference: -1.24
Layer:3 | Original Loss: 11.49 | Reconstructed Loss: 11.28 | Difference: -0.21
Layer:4 | Original Loss: 11.49 | Reconstructed Loss: 10.08 | Difference: -1.41


 43%|████▎     | 59/138 [01:46<02:23,  1.81s/it]

Layer:5 | Original Loss: 11.49 | Reconstructed Loss: 16.99 | Difference: 5.50
Layer:1 | Original Loss: 11.48 | Reconstructed Loss: 9.92 | Difference: -1.56
Layer:2 | Original Loss: 11.48 | Reconstructed Loss: 10.24 | Difference: -1.24
Layer:3 | Original Loss: 11.48 | Reconstructed Loss: 11.27 | Difference: -0.21
Layer:4 | Original Loss: 11.48 | Reconstructed Loss: 10.07 | Difference: -1.41


 43%|████▎     | 60/138 [01:48<02:21,  1.81s/it]

Layer:5 | Original Loss: 11.48 | Reconstructed Loss: 16.98 | Difference: 5.50
Layer:1 | Original Loss: 11.50 | Reconstructed Loss: 9.93 | Difference: -1.56
Layer:2 | Original Loss: 11.50 | Reconstructed Loss: 10.26 | Difference: -1.24
Layer:3 | Original Loss: 11.50 | Reconstructed Loss: 11.29 | Difference: -0.21
Layer:4 | Original Loss: 11.50 | Reconstructed Loss: 10.08 | Difference: -1.41


 44%|████▍     | 61/138 [01:50<02:19,  1.81s/it]

Layer:5 | Original Loss: 11.50 | Reconstructed Loss: 17.00 | Difference: 5.51
Layer:1 | Original Loss: 11.54 | Reconstructed Loss: 9.97 | Difference: -1.57
Layer:2 | Original Loss: 11.54 | Reconstructed Loss: 10.30 | Difference: -1.24
Layer:3 | Original Loss: 11.54 | Reconstructed Loss: 11.33 | Difference: -0.21
Layer:4 | Original Loss: 11.54 | Reconstructed Loss: 10.12 | Difference: -1.42


 45%|████▍     | 62/138 [01:52<02:17,  1.81s/it]

Layer:5 | Original Loss: 11.54 | Reconstructed Loss: 17.04 | Difference: 5.49
Layer:1 | Original Loss: 11.53 | Reconstructed Loss: 9.97 | Difference: -1.57
Layer:2 | Original Loss: 11.53 | Reconstructed Loss: 10.29 | Difference: -1.24
Layer:3 | Original Loss: 11.53 | Reconstructed Loss: 11.32 | Difference: -0.21
Layer:4 | Original Loss: 11.53 | Reconstructed Loss: 10.12 | Difference: -1.42


 46%|████▌     | 63/138 [01:54<02:16,  1.82s/it]

Layer:5 | Original Loss: 11.53 | Reconstructed Loss: 17.03 | Difference: 5.49
Layer:1 | Original Loss: 11.53 | Reconstructed Loss: 9.96 | Difference: -1.56
Layer:2 | Original Loss: 11.53 | Reconstructed Loss: 10.29 | Difference: -1.24
Layer:3 | Original Loss: 11.53 | Reconstructed Loss: 11.32 | Difference: -0.21
Layer:4 | Original Loss: 11.53 | Reconstructed Loss: 10.11 | Difference: -1.42


 46%|████▋     | 64/138 [01:55<02:14,  1.81s/it]

Layer:5 | Original Loss: 11.53 | Reconstructed Loss: 17.02 | Difference: 5.49
Layer:1 | Original Loss: 11.53 | Reconstructed Loss: 9.97 | Difference: -1.56
Layer:2 | Original Loss: 11.53 | Reconstructed Loss: 10.29 | Difference: -1.24
Layer:3 | Original Loss: 11.53 | Reconstructed Loss: 11.32 | Difference: -0.21
Layer:4 | Original Loss: 11.53 | Reconstructed Loss: 10.11 | Difference: -1.42


 47%|████▋     | 65/138 [01:57<02:12,  1.82s/it]

Layer:5 | Original Loss: 11.53 | Reconstructed Loss: 17.02 | Difference: 5.49
Layer:1 | Original Loss: 11.52 | Reconstructed Loss: 9.96 | Difference: -1.56
Layer:2 | Original Loss: 11.52 | Reconstructed Loss: 10.28 | Difference: -1.24
Layer:3 | Original Loss: 11.52 | Reconstructed Loss: 11.31 | Difference: -0.21
Layer:4 | Original Loss: 11.52 | Reconstructed Loss: 10.11 | Difference: -1.42


 48%|████▊     | 66/138 [01:59<02:10,  1.82s/it]

Layer:5 | Original Loss: 11.52 | Reconstructed Loss: 17.01 | Difference: 5.49
Layer:1 | Original Loss: 11.54 | Reconstructed Loss: 9.97 | Difference: -1.56
Layer:2 | Original Loss: 11.54 | Reconstructed Loss: 10.30 | Difference: -1.24
Layer:3 | Original Loss: 11.54 | Reconstructed Loss: 11.32 | Difference: -0.22
Layer:4 | Original Loss: 11.54 | Reconstructed Loss: 10.12 | Difference: -1.42


 49%|████▊     | 67/138 [02:01<02:09,  1.82s/it]

Layer:5 | Original Loss: 11.54 | Reconstructed Loss: 17.02 | Difference: 5.48
Layer:1 | Original Loss: 11.53 | Reconstructed Loss: 9.97 | Difference: -1.56
Layer:2 | Original Loss: 11.53 | Reconstructed Loss: 10.29 | Difference: -1.24
Layer:3 | Original Loss: 11.53 | Reconstructed Loss: 11.32 | Difference: -0.22
Layer:4 | Original Loss: 11.53 | Reconstructed Loss: 10.12 | Difference: -1.41


 49%|████▉     | 68/138 [02:03<02:07,  1.82s/it]

Layer:5 | Original Loss: 11.53 | Reconstructed Loss: 17.01 | Difference: 5.48
Layer:1 | Original Loss: 11.54 | Reconstructed Loss: 9.97 | Difference: -1.56
Layer:2 | Original Loss: 11.54 | Reconstructed Loss: 10.30 | Difference: -1.24
Layer:3 | Original Loss: 11.54 | Reconstructed Loss: 11.32 | Difference: -0.22
Layer:4 | Original Loss: 11.54 | Reconstructed Loss: 10.12 | Difference: -1.41


 50%|█████     | 69/138 [02:04<02:05,  1.82s/it]

Layer:5 | Original Loss: 11.54 | Reconstructed Loss: 17.02 | Difference: 5.48
Layer:1 | Original Loss: 11.53 | Reconstructed Loss: 9.97 | Difference: -1.56
Layer:2 | Original Loss: 11.53 | Reconstructed Loss: 10.29 | Difference: -1.24
Layer:3 | Original Loss: 11.53 | Reconstructed Loss: 11.31 | Difference: -0.22
Layer:4 | Original Loss: 11.53 | Reconstructed Loss: 10.12 | Difference: -1.41


 51%|█████     | 70/138 [02:06<02:03,  1.82s/it]

Layer:5 | Original Loss: 11.53 | Reconstructed Loss: 17.01 | Difference: 5.48
Layer:1 | Original Loss: 11.52 | Reconstructed Loss: 9.96 | Difference: -1.56
Layer:2 | Original Loss: 11.52 | Reconstructed Loss: 10.28 | Difference: -1.24
Layer:3 | Original Loss: 11.52 | Reconstructed Loss: 11.31 | Difference: -0.22
Layer:4 | Original Loss: 11.52 | Reconstructed Loss: 10.11 | Difference: -1.41


 51%|█████▏    | 71/138 [02:08<02:02,  1.82s/it]

Layer:5 | Original Loss: 11.52 | Reconstructed Loss: 17.00 | Difference: 5.48
Layer:1 | Original Loss: 11.52 | Reconstructed Loss: 9.96 | Difference: -1.56
Layer:2 | Original Loss: 11.52 | Reconstructed Loss: 10.28 | Difference: -1.24
Layer:3 | Original Loss: 11.52 | Reconstructed Loss: 11.30 | Difference: -0.22
Layer:4 | Original Loss: 11.52 | Reconstructed Loss: 10.11 | Difference: -1.41


 52%|█████▏    | 72/138 [02:10<02:00,  1.82s/it]

Layer:5 | Original Loss: 11.52 | Reconstructed Loss: 17.00 | Difference: 5.48
Layer:1 | Original Loss: 11.52 | Reconstructed Loss: 9.96 | Difference: -1.56
Layer:2 | Original Loss: 11.52 | Reconstructed Loss: 10.28 | Difference: -1.24
Layer:3 | Original Loss: 11.52 | Reconstructed Loss: 11.31 | Difference: -0.22
Layer:4 | Original Loss: 11.52 | Reconstructed Loss: 10.11 | Difference: -1.41


 53%|█████▎    | 73/138 [02:12<01:58,  1.82s/it]

Layer:5 | Original Loss: 11.52 | Reconstructed Loss: 17.01 | Difference: 5.48
Layer:1 | Original Loss: 11.52 | Reconstructed Loss: 9.96 | Difference: -1.56
Layer:2 | Original Loss: 11.52 | Reconstructed Loss: 10.28 | Difference: -1.24
Layer:3 | Original Loss: 11.52 | Reconstructed Loss: 11.30 | Difference: -0.22
Layer:4 | Original Loss: 11.52 | Reconstructed Loss: 10.11 | Difference: -1.41


 54%|█████▎    | 74/138 [02:14<01:56,  1.82s/it]

Layer:5 | Original Loss: 11.52 | Reconstructed Loss: 17.00 | Difference: 5.48
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.27 | Difference: -1.24
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 54%|█████▍    | 75/138 [02:15<01:55,  1.83s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.99 | Difference: 5.48
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.26 | Difference: -1.24
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.09 | Difference: -1.41


 55%|█████▌    | 76/138 [02:17<01:53,  1.83s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.99 | Difference: 5.48
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.27 | Difference: -1.24
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 56%|█████▌    | 77/138 [02:19<01:51,  1.83s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.98 | Difference: 5.48
Layer:1 | Original Loss: 11.50 | Reconstructed Loss: 9.94 | Difference: -1.56
Layer:2 | Original Loss: 11.50 | Reconstructed Loss: 10.26 | Difference: -1.24
Layer:3 | Original Loss: 11.50 | Reconstructed Loss: 11.28 | Difference: -0.22
Layer:4 | Original Loss: 11.50 | Reconstructed Loss: 10.09 | Difference: -1.41


 57%|█████▋    | 78/138 [02:21<01:49,  1.83s/it]

Layer:5 | Original Loss: 11.50 | Reconstructed Loss: 16.97 | Difference: 5.48
Layer:1 | Original Loss: 11.59 | Reconstructed Loss: 10.03 | Difference: -1.56
Layer:2 | Original Loss: 11.59 | Reconstructed Loss: 10.35 | Difference: -1.24
Layer:3 | Original Loss: 11.59 | Reconstructed Loss: 11.37 | Difference: -0.22
Layer:4 | Original Loss: 11.59 | Reconstructed Loss: 10.18 | Difference: -1.42


 57%|█████▋    | 79/138 [02:23<01:47,  1.83s/it]

Layer:5 | Original Loss: 11.59 | Reconstructed Loss: 17.06 | Difference: 5.46
Layer:1 | Original Loss: 11.59 | Reconstructed Loss: 10.02 | Difference: -1.56
Layer:2 | Original Loss: 11.59 | Reconstructed Loss: 10.34 | Difference: -1.24
Layer:3 | Original Loss: 11.59 | Reconstructed Loss: 11.36 | Difference: -0.22
Layer:4 | Original Loss: 11.59 | Reconstructed Loss: 10.17 | Difference: -1.42


 58%|█████▊    | 80/138 [02:25<01:46,  1.83s/it]

Layer:5 | Original Loss: 11.59 | Reconstructed Loss: 17.05 | Difference: 5.47
Layer:1 | Original Loss: 11.58 | Reconstructed Loss: 10.01 | Difference: -1.56
Layer:2 | Original Loss: 11.58 | Reconstructed Loss: 10.33 | Difference: -1.24
Layer:3 | Original Loss: 11.58 | Reconstructed Loss: 11.35 | Difference: -0.22
Layer:4 | Original Loss: 11.58 | Reconstructed Loss: 10.16 | Difference: -1.42


 59%|█████▊    | 81/138 [02:26<01:44,  1.83s/it]

Layer:5 | Original Loss: 11.58 | Reconstructed Loss: 17.04 | Difference: 5.47
Layer:1 | Original Loss: 11.57 | Reconstructed Loss: 10.01 | Difference: -1.56
Layer:2 | Original Loss: 11.57 | Reconstructed Loss: 10.33 | Difference: -1.24
Layer:3 | Original Loss: 11.57 | Reconstructed Loss: 11.35 | Difference: -0.22
Layer:4 | Original Loss: 11.57 | Reconstructed Loss: 10.15 | Difference: -1.42


 59%|█████▉    | 82/138 [02:28<01:42,  1.83s/it]

Layer:5 | Original Loss: 11.57 | Reconstructed Loss: 17.04 | Difference: 5.47
Layer:1 | Original Loss: 11.56 | Reconstructed Loss: 10.00 | Difference: -1.56
Layer:2 | Original Loss: 11.56 | Reconstructed Loss: 10.32 | Difference: -1.24
Layer:3 | Original Loss: 11.56 | Reconstructed Loss: 11.34 | Difference: -0.22
Layer:4 | Original Loss: 11.56 | Reconstructed Loss: 10.15 | Difference: -1.42


 60%|██████    | 83/138 [02:30<01:40,  1.83s/it]

Layer:5 | Original Loss: 11.56 | Reconstructed Loss: 17.03 | Difference: 5.47
Layer:1 | Original Loss: 11.56 | Reconstructed Loss: 10.00 | Difference: -1.56
Layer:2 | Original Loss: 11.56 | Reconstructed Loss: 10.32 | Difference: -1.24
Layer:3 | Original Loss: 11.56 | Reconstructed Loss: 11.34 | Difference: -0.22
Layer:4 | Original Loss: 11.56 | Reconstructed Loss: 10.14 | Difference: -1.42


 61%|██████    | 84/138 [02:32<01:39,  1.83s/it]

Layer:5 | Original Loss: 11.56 | Reconstructed Loss: 17.03 | Difference: 5.47
Layer:1 | Original Loss: 11.56 | Reconstructed Loss: 10.00 | Difference: -1.56
Layer:2 | Original Loss: 11.56 | Reconstructed Loss: 10.32 | Difference: -1.24
Layer:3 | Original Loss: 11.56 | Reconstructed Loss: 11.33 | Difference: -0.22
Layer:4 | Original Loss: 11.56 | Reconstructed Loss: 10.14 | Difference: -1.42


 62%|██████▏   | 85/138 [02:34<01:37,  1.84s/it]

Layer:5 | Original Loss: 11.56 | Reconstructed Loss: 17.02 | Difference: 5.47
Layer:1 | Original Loss: 11.56 | Reconstructed Loss: 10.00 | Difference: -1.56
Layer:2 | Original Loss: 11.56 | Reconstructed Loss: 10.32 | Difference: -1.24
Layer:3 | Original Loss: 11.56 | Reconstructed Loss: 11.33 | Difference: -0.22
Layer:4 | Original Loss: 11.56 | Reconstructed Loss: 10.14 | Difference: -1.41


 62%|██████▏   | 86/138 [02:36<01:35,  1.83s/it]

Layer:5 | Original Loss: 11.56 | Reconstructed Loss: 17.02 | Difference: 5.47
Layer:1 | Original Loss: 11.55 | Reconstructed Loss: 9.99 | Difference: -1.56
Layer:2 | Original Loss: 11.55 | Reconstructed Loss: 10.31 | Difference: -1.24
Layer:3 | Original Loss: 11.55 | Reconstructed Loss: 11.33 | Difference: -0.22
Layer:4 | Original Loss: 11.55 | Reconstructed Loss: 10.14 | Difference: -1.41


 63%|██████▎   | 87/138 [02:37<01:33,  1.84s/it]

Layer:5 | Original Loss: 11.55 | Reconstructed Loss: 17.02 | Difference: 5.47
Layer:1 | Original Loss: 11.55 | Reconstructed Loss: 9.99 | Difference: -1.56
Layer:2 | Original Loss: 11.55 | Reconstructed Loss: 10.31 | Difference: -1.24
Layer:3 | Original Loss: 11.55 | Reconstructed Loss: 11.33 | Difference: -0.22
Layer:4 | Original Loss: 11.55 | Reconstructed Loss: 10.14 | Difference: -1.41


 64%|██████▍   | 88/138 [02:39<01:31,  1.84s/it]

Layer:5 | Original Loss: 11.55 | Reconstructed Loss: 17.02 | Difference: 5.46
Layer:1 | Original Loss: 11.55 | Reconstructed Loss: 9.99 | Difference: -1.56
Layer:2 | Original Loss: 11.55 | Reconstructed Loss: 10.31 | Difference: -1.24
Layer:3 | Original Loss: 11.55 | Reconstructed Loss: 11.33 | Difference: -0.22
Layer:4 | Original Loss: 11.55 | Reconstructed Loss: 10.13 | Difference: -1.41


 64%|██████▍   | 89/138 [02:41<01:30,  1.84s/it]

Layer:5 | Original Loss: 11.55 | Reconstructed Loss: 17.01 | Difference: 5.46
Layer:1 | Original Loss: 11.54 | Reconstructed Loss: 9.98 | Difference: -1.56
Layer:2 | Original Loss: 11.54 | Reconstructed Loss: 10.30 | Difference: -1.24
Layer:3 | Original Loss: 11.54 | Reconstructed Loss: 11.32 | Difference: -0.22
Layer:4 | Original Loss: 11.54 | Reconstructed Loss: 10.13 | Difference: -1.41


 65%|██████▌   | 90/138 [02:43<01:28,  1.84s/it]

Layer:5 | Original Loss: 11.54 | Reconstructed Loss: 17.00 | Difference: 5.47
Layer:1 | Original Loss: 11.53 | Reconstructed Loss: 9.97 | Difference: -1.56
Layer:2 | Original Loss: 11.53 | Reconstructed Loss: 10.29 | Difference: -1.24
Layer:3 | Original Loss: 11.53 | Reconstructed Loss: 11.31 | Difference: -0.22
Layer:4 | Original Loss: 11.53 | Reconstructed Loss: 10.12 | Difference: -1.41


 66%|██████▌   | 91/138 [02:45<01:26,  1.84s/it]

Layer:5 | Original Loss: 11.53 | Reconstructed Loss: 17.00 | Difference: 5.46
Layer:1 | Original Loss: 11.53 | Reconstructed Loss: 9.97 | Difference: -1.56
Layer:2 | Original Loss: 11.53 | Reconstructed Loss: 10.29 | Difference: -1.24
Layer:3 | Original Loss: 11.53 | Reconstructed Loss: 11.30 | Difference: -0.22
Layer:4 | Original Loss: 11.53 | Reconstructed Loss: 10.12 | Difference: -1.41


 67%|██████▋   | 92/138 [02:47<01:24,  1.84s/it]

Layer:5 | Original Loss: 11.53 | Reconstructed Loss: 16.99 | Difference: 5.46
Layer:1 | Original Loss: 11.52 | Reconstructed Loss: 9.97 | Difference: -1.56
Layer:2 | Original Loss: 11.52 | Reconstructed Loss: 10.29 | Difference: -1.24
Layer:3 | Original Loss: 11.52 | Reconstructed Loss: 11.30 | Difference: -0.22
Layer:4 | Original Loss: 11.52 | Reconstructed Loss: 10.11 | Difference: -1.41


 67%|██████▋   | 93/138 [02:48<01:22,  1.84s/it]

Layer:5 | Original Loss: 11.52 | Reconstructed Loss: 16.98 | Difference: 5.46
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.28 | Difference: -1.24
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 68%|██████▊   | 94/138 [02:50<01:21,  1.84s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.98 | Difference: 5.46
Layer:1 | Original Loss: 11.50 | Reconstructed Loss: 9.94 | Difference: -1.56
Layer:2 | Original Loss: 11.50 | Reconstructed Loss: 10.27 | Difference: -1.24
Layer:3 | Original Loss: 11.50 | Reconstructed Loss: 11.28 | Difference: -0.22
Layer:4 | Original Loss: 11.50 | Reconstructed Loss: 10.09 | Difference: -1.41


 69%|██████▉   | 95/138 [02:52<01:19,  1.84s/it]

Layer:5 | Original Loss: 11.50 | Reconstructed Loss: 16.97 | Difference: 5.47
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.27 | Difference: -1.24
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.28 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 70%|██████▉   | 96/138 [02:54<01:17,  1.84s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.97 | Difference: 5.47
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.28 | Difference: -1.24
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 70%|███████   | 97/138 [02:56<01:15,  1.84s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.98 | Difference: 5.47
Layer:1 | Original Loss: 11.50 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.50 | Reconstructed Loss: 10.27 | Difference: -1.24
Layer:3 | Original Loss: 11.50 | Reconstructed Loss: 11.28 | Difference: -0.22
Layer:4 | Original Loss: 11.50 | Reconstructed Loss: 10.10 | Difference: -1.41


 71%|███████   | 98/138 [02:58<01:13,  1.85s/it]

Layer:5 | Original Loss: 11.50 | Reconstructed Loss: 16.97 | Difference: 5.46
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.27 | Difference: -1.23
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 72%|███████▏  | 99/138 [03:00<01:11,  1.85s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.98 | Difference: 5.47
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.28 | Difference: -1.23
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 72%|███████▏  | 100/138 [03:01<01:10,  1.85s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.98 | Difference: 5.47
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.27 | Difference: -1.23
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 73%|███████▎  | 101/138 [03:03<01:08,  1.85s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.98 | Difference: 5.47
Layer:1 | Original Loss: 11.52 | Reconstructed Loss: 9.96 | Difference: -1.56
Layer:2 | Original Loss: 11.52 | Reconstructed Loss: 10.28 | Difference: -1.23
Layer:3 | Original Loss: 11.52 | Reconstructed Loss: 11.30 | Difference: -0.22
Layer:4 | Original Loss: 11.52 | Reconstructed Loss: 10.11 | Difference: -1.41


 74%|███████▍  | 102/138 [03:05<01:06,  1.85s/it]

Layer:5 | Original Loss: 11.52 | Reconstructed Loss: 16.99 | Difference: 5.47
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.28 | Difference: -1.23
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.30 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.11 | Difference: -1.41


 75%|███████▍  | 103/138 [03:07<01:04,  1.85s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.98 | Difference: 5.47
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.28 | Difference: -1.23
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 75%|███████▌  | 104/138 [03:09<01:02,  1.85s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.98 | Difference: 5.47
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.95 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.28 | Difference: -1.23
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 76%|███████▌  | 105/138 [03:11<01:01,  1.85s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.97 | Difference: 5.46
Layer:1 | Original Loss: 11.51 | Reconstructed Loss: 9.94 | Difference: -1.56
Layer:2 | Original Loss: 11.51 | Reconstructed Loss: 10.27 | Difference: -1.23
Layer:3 | Original Loss: 11.51 | Reconstructed Loss: 11.29 | Difference: -0.22
Layer:4 | Original Loss: 11.51 | Reconstructed Loss: 10.10 | Difference: -1.41


 77%|███████▋  | 106/138 [03:12<00:59,  1.85s/it]

Layer:5 | Original Loss: 11.51 | Reconstructed Loss: 16.97 | Difference: 5.46
Layer:1 | Original Loss: 11.50 | Reconstructed Loss: 9.94 | Difference: -1.56
Layer:2 | Original Loss: 11.50 | Reconstructed Loss: 10.27 | Difference: -1.23
Layer:3 | Original Loss: 11.50 | Reconstructed Loss: 11.28 | Difference: -0.22
Layer:4 | Original Loss: 11.50 | Reconstructed Loss: 10.09 | Difference: -1.41


 78%|███████▊  | 107/138 [03:14<00:57,  1.85s/it]

Layer:5 | Original Loss: 11.50 | Reconstructed Loss: 16.97 | Difference: 5.46
Layer:1 | Original Loss: 11.49 | Reconstructed Loss: 9.93 | Difference: -1.56
Layer:2 | Original Loss: 11.49 | Reconstructed Loss: 10.26 | Difference: -1.23
Layer:3 | Original Loss: 11.49 | Reconstructed Loss: 11.27 | Difference: -0.22
Layer:4 | Original Loss: 11.49 | Reconstructed Loss: 10.08 | Difference: -1.41


 78%|███████▊  | 108/138 [03:16<00:55,  1.85s/it]

Layer:5 | Original Loss: 11.49 | Reconstructed Loss: 16.95 | Difference: 5.46
Layer:1 | Original Loss: 11.49 | Reconstructed Loss: 9.93 | Difference: -1.56
Layer:2 | Original Loss: 11.49 | Reconstructed Loss: 10.26 | Difference: -1.23
Layer:3 | Original Loss: 11.49 | Reconstructed Loss: 11.27 | Difference: -0.22
Layer:4 | Original Loss: 11.49 | Reconstructed Loss: 10.08 | Difference: -1.41


 79%|███████▉  | 109/138 [03:18<00:53,  1.85s/it]

Layer:5 | Original Loss: 11.49 | Reconstructed Loss: 16.95 | Difference: 5.46
Layer:1 | Original Loss: 11.48 | Reconstructed Loss: 9.92 | Difference: -1.56
Layer:2 | Original Loss: 11.48 | Reconstructed Loss: 10.25 | Difference: -1.23
Layer:3 | Original Loss: 11.48 | Reconstructed Loss: 11.26 | Difference: -0.22
Layer:4 | Original Loss: 11.48 | Reconstructed Loss: 10.08 | Difference: -1.41


 80%|███████▉  | 110/138 [03:20<00:51,  1.85s/it]

Layer:5 | Original Loss: 11.48 | Reconstructed Loss: 16.94 | Difference: 5.46
Layer:1 | Original Loss: 11.48 | Reconstructed Loss: 9.92 | Difference: -1.56
Layer:2 | Original Loss: 11.48 | Reconstructed Loss: 10.25 | Difference: -1.23
Layer:3 | Original Loss: 11.48 | Reconstructed Loss: 11.26 | Difference: -0.22
Layer:4 | Original Loss: 11.48 | Reconstructed Loss: 10.07 | Difference: -1.41


 80%|████████  | 111/138 [03:22<00:50,  1.85s/it]

Layer:5 | Original Loss: 11.48 | Reconstructed Loss: 16.94 | Difference: 5.46
Layer:1 | Original Loss: 11.48 | Reconstructed Loss: 9.92 | Difference: -1.56
Layer:2 | Original Loss: 11.48 | Reconstructed Loss: 10.25 | Difference: -1.23
Layer:3 | Original Loss: 11.48 | Reconstructed Loss: 11.26 | Difference: -0.22
Layer:4 | Original Loss: 11.48 | Reconstructed Loss: 10.07 | Difference: -1.41


 81%|████████  | 112/138 [03:24<00:48,  1.85s/it]

Layer:5 | Original Loss: 11.48 | Reconstructed Loss: 16.94 | Difference: 5.46
Layer:1 | Original Loss: 11.47 | Reconstructed Loss: 9.91 | Difference: -1.56
Layer:2 | Original Loss: 11.47 | Reconstructed Loss: 10.24 | Difference: -1.23
Layer:3 | Original Loss: 11.47 | Reconstructed Loss: 11.25 | Difference: -0.22
Layer:4 | Original Loss: 11.47 | Reconstructed Loss: 10.07 | Difference: -1.41


 82%|████████▏ | 113/138 [03:25<00:46,  1.85s/it]

Layer:5 | Original Loss: 11.47 | Reconstructed Loss: 16.93 | Difference: 5.46
Layer:1 | Original Loss: 11.47 | Reconstructed Loss: 9.91 | Difference: -1.56
Layer:2 | Original Loss: 11.47 | Reconstructed Loss: 10.24 | Difference: -1.23
Layer:3 | Original Loss: 11.47 | Reconstructed Loss: 11.25 | Difference: -0.22
Layer:4 | Original Loss: 11.47 | Reconstructed Loss: 10.07 | Difference: -1.41


 83%|████████▎ | 114/138 [03:27<00:44,  1.85s/it]

Layer:5 | Original Loss: 11.47 | Reconstructed Loss: 16.93 | Difference: 5.46
Layer:1 | Original Loss: 11.46 | Reconstructed Loss: 9.90 | Difference: -1.56
Layer:2 | Original Loss: 11.46 | Reconstructed Loss: 10.23 | Difference: -1.23
Layer:3 | Original Loss: 11.46 | Reconstructed Loss: 11.24 | Difference: -0.22
Layer:4 | Original Loss: 11.46 | Reconstructed Loss: 10.05 | Difference: -1.40


 83%|████████▎ | 115/138 [03:29<00:42,  1.85s/it]

Layer:5 | Original Loss: 11.46 | Reconstructed Loss: 16.92 | Difference: 5.46
Layer:1 | Original Loss: 11.45 | Reconstructed Loss: 9.90 | Difference: -1.56
Layer:2 | Original Loss: 11.45 | Reconstructed Loss: 10.22 | Difference: -1.23
Layer:3 | Original Loss: 11.45 | Reconstructed Loss: 11.23 | Difference: -0.22
Layer:4 | Original Loss: 11.45 | Reconstructed Loss: 10.05 | Difference: -1.40


 84%|████████▍ | 116/138 [03:31<00:40,  1.85s/it]

Layer:5 | Original Loss: 11.45 | Reconstructed Loss: 16.91 | Difference: 5.46
Layer:1 | Original Loss: 11.45 | Reconstructed Loss: 9.89 | Difference: -1.56
Layer:2 | Original Loss: 11.45 | Reconstructed Loss: 10.22 | Difference: -1.23
Layer:3 | Original Loss: 11.45 | Reconstructed Loss: 11.23 | Difference: -0.22
Layer:4 | Original Loss: 11.45 | Reconstructed Loss: 10.05 | Difference: -1.40


 85%|████████▍ | 117/138 [03:33<00:38,  1.86s/it]

Layer:5 | Original Loss: 11.45 | Reconstructed Loss: 16.91 | Difference: 5.46
Layer:1 | Original Loss: 11.45 | Reconstructed Loss: 9.89 | Difference: -1.56
Layer:2 | Original Loss: 11.45 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.45 | Reconstructed Loss: 11.23 | Difference: -0.22
Layer:4 | Original Loss: 11.45 | Reconstructed Loss: 10.04 | Difference: -1.40


 86%|████████▌ | 118/138 [03:35<00:37,  1.86s/it]

Layer:5 | Original Loss: 11.45 | Reconstructed Loss: 16.90 | Difference: 5.46
Layer:1 | Original Loss: 11.44 | Reconstructed Loss: 9.88 | Difference: -1.56
Layer:2 | Original Loss: 11.44 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.44 | Reconstructed Loss: 11.22 | Difference: -0.22
Layer:4 | Original Loss: 11.44 | Reconstructed Loss: 10.04 | Difference: -1.40


 86%|████████▌ | 119/138 [03:37<00:35,  1.86s/it]

Layer:5 | Original Loss: 11.44 | Reconstructed Loss: 16.90 | Difference: 5.46
Layer:1 | Original Loss: 11.44 | Reconstructed Loss: 9.89 | Difference: -1.56
Layer:2 | Original Loss: 11.44 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.44 | Reconstructed Loss: 11.22 | Difference: -0.22
Layer:4 | Original Loss: 11.44 | Reconstructed Loss: 10.04 | Difference: -1.40


 87%|████████▋ | 120/138 [03:38<00:33,  1.86s/it]

Layer:5 | Original Loss: 11.44 | Reconstructed Loss: 16.90 | Difference: 5.46
Layer:1 | Original Loss: 11.44 | Reconstructed Loss: 9.88 | Difference: -1.56
Layer:2 | Original Loss: 11.44 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.44 | Reconstructed Loss: 11.22 | Difference: -0.22
Layer:4 | Original Loss: 11.44 | Reconstructed Loss: 10.03 | Difference: -1.40


 88%|████████▊ | 121/138 [03:40<00:31,  1.86s/it]

Layer:5 | Original Loss: 11.44 | Reconstructed Loss: 16.89 | Difference: 5.46
Layer:1 | Original Loss: 11.43 | Reconstructed Loss: 9.87 | Difference: -1.56
Layer:2 | Original Loss: 11.43 | Reconstructed Loss: 10.20 | Difference: -1.23
Layer:3 | Original Loss: 11.43 | Reconstructed Loss: 11.21 | Difference: -0.22
Layer:4 | Original Loss: 11.43 | Reconstructed Loss: 10.03 | Difference: -1.40


 88%|████████▊ | 122/138 [03:42<00:29,  1.86s/it]

Layer:5 | Original Loss: 11.43 | Reconstructed Loss: 16.89 | Difference: 5.46
Layer:1 | Original Loss: 11.42 | Reconstructed Loss: 9.87 | Difference: -1.56
Layer:2 | Original Loss: 11.42 | Reconstructed Loss: 10.19 | Difference: -1.23
Layer:3 | Original Loss: 11.42 | Reconstructed Loss: 11.20 | Difference: -0.22
Layer:4 | Original Loss: 11.42 | Reconstructed Loss: 10.02 | Difference: -1.40


 89%|████████▉ | 123/138 [03:44<00:27,  1.86s/it]

Layer:5 | Original Loss: 11.42 | Reconstructed Loss: 16.88 | Difference: 5.46
Layer:1 | Original Loss: 11.41 | Reconstructed Loss: 9.86 | Difference: -1.56
Layer:2 | Original Loss: 11.41 | Reconstructed Loss: 10.18 | Difference: -1.23
Layer:3 | Original Loss: 11.41 | Reconstructed Loss: 11.19 | Difference: -0.22
Layer:4 | Original Loss: 11.41 | Reconstructed Loss: 10.01 | Difference: -1.40


 90%|████████▉ | 124/138 [03:46<00:26,  1.86s/it]

Layer:5 | Original Loss: 11.41 | Reconstructed Loss: 16.87 | Difference: 5.46
Layer:1 | Original Loss: 11.41 | Reconstructed Loss: 9.85 | Difference: -1.56
Layer:2 | Original Loss: 11.41 | Reconstructed Loss: 10.18 | Difference: -1.23
Layer:3 | Original Loss: 11.41 | Reconstructed Loss: 11.19 | Difference: -0.22
Layer:4 | Original Loss: 11.41 | Reconstructed Loss: 10.01 | Difference: -1.40


 91%|█████████ | 125/138 [03:48<00:24,  1.86s/it]

Layer:5 | Original Loss: 11.41 | Reconstructed Loss: 16.87 | Difference: 5.46
Layer:1 | Original Loss: 11.40 | Reconstructed Loss: 9.85 | Difference: -1.56
Layer:2 | Original Loss: 11.40 | Reconstructed Loss: 10.17 | Difference: -1.23
Layer:3 | Original Loss: 11.40 | Reconstructed Loss: 11.18 | Difference: -0.22
Layer:4 | Original Loss: 11.40 | Reconstructed Loss: 10.00 | Difference: -1.40


 91%|█████████▏| 126/138 [03:50<00:22,  1.86s/it]

Layer:5 | Original Loss: 11.40 | Reconstructed Loss: 16.86 | Difference: 5.46
Layer:1 | Original Loss: 11.42 | Reconstructed Loss: 9.86 | Difference: -1.56
Layer:2 | Original Loss: 11.42 | Reconstructed Loss: 10.19 | Difference: -1.23
Layer:3 | Original Loss: 11.42 | Reconstructed Loss: 11.20 | Difference: -0.22
Layer:4 | Original Loss: 11.42 | Reconstructed Loss: 10.02 | Difference: -1.40


 92%|█████████▏| 127/138 [03:51<00:20,  1.86s/it]

Layer:5 | Original Loss: 11.42 | Reconstructed Loss: 16.88 | Difference: 5.46
Layer:1 | Original Loss: 11.41 | Reconstructed Loss: 9.85 | Difference: -1.56
Layer:2 | Original Loss: 11.41 | Reconstructed Loss: 10.18 | Difference: -1.23
Layer:3 | Original Loss: 11.41 | Reconstructed Loss: 11.19 | Difference: -0.22
Layer:4 | Original Loss: 11.41 | Reconstructed Loss: 10.01 | Difference: -1.40


 93%|█████████▎| 128/138 [03:53<00:18,  1.86s/it]

Layer:5 | Original Loss: 11.41 | Reconstructed Loss: 16.87 | Difference: 5.46
Layer:1 | Original Loss: 11.42 | Reconstructed Loss: 9.86 | Difference: -1.56
Layer:2 | Original Loss: 11.42 | Reconstructed Loss: 10.19 | Difference: -1.23
Layer:3 | Original Loss: 11.42 | Reconstructed Loss: 11.20 | Difference: -0.22
Layer:4 | Original Loss: 11.42 | Reconstructed Loss: 10.02 | Difference: -1.40


 93%|█████████▎| 129/138 [03:55<00:16,  1.86s/it]

Layer:5 | Original Loss: 11.42 | Reconstructed Loss: 16.88 | Difference: 5.46
Layer:1 | Original Loss: 11.42 | Reconstructed Loss: 9.86 | Difference: -1.56
Layer:2 | Original Loss: 11.42 | Reconstructed Loss: 10.19 | Difference: -1.23
Layer:3 | Original Loss: 11.42 | Reconstructed Loss: 11.20 | Difference: -0.22
Layer:4 | Original Loss: 11.42 | Reconstructed Loss: 10.01 | Difference: -1.40


 94%|█████████▍| 130/138 [03:57<00:14,  1.86s/it]

Layer:5 | Original Loss: 11.42 | Reconstructed Loss: 16.88 | Difference: 5.46
Layer:1 | Original Loss: 11.42 | Reconstructed Loss: 9.86 | Difference: -1.56
Layer:2 | Original Loss: 11.42 | Reconstructed Loss: 10.19 | Difference: -1.23
Layer:3 | Original Loss: 11.42 | Reconstructed Loss: 11.20 | Difference: -0.22
Layer:4 | Original Loss: 11.42 | Reconstructed Loss: 10.01 | Difference: -1.41


 95%|█████████▍| 131/138 [03:59<00:13,  1.86s/it]

Layer:5 | Original Loss: 11.42 | Reconstructed Loss: 16.88 | Difference: 5.46
Layer:1 | Original Loss: 11.44 | Reconstructed Loss: 9.89 | Difference: -1.56
Layer:2 | Original Loss: 11.44 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.44 | Reconstructed Loss: 11.22 | Difference: -0.22
Layer:4 | Original Loss: 11.44 | Reconstructed Loss: 10.03 | Difference: -1.41


 96%|█████████▌| 132/138 [04:01<00:11,  1.86s/it]

Layer:5 | Original Loss: 11.44 | Reconstructed Loss: 16.90 | Difference: 5.45
Layer:1 | Original Loss: 11.44 | Reconstructed Loss: 9.88 | Difference: -1.56
Layer:2 | Original Loss: 11.44 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.44 | Reconstructed Loss: 11.22 | Difference: -0.22
Layer:4 | Original Loss: 11.44 | Reconstructed Loss: 10.03 | Difference: -1.41


 96%|█████████▋| 133/138 [04:03<00:09,  1.86s/it]

Layer:5 | Original Loss: 11.44 | Reconstructed Loss: 16.89 | Difference: 5.45
Layer:1 | Original Loss: 11.44 | Reconstructed Loss: 9.88 | Difference: -1.56
Layer:2 | Original Loss: 11.44 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.44 | Reconstructed Loss: 11.21 | Difference: -0.22
Layer:4 | Original Loss: 11.44 | Reconstructed Loss: 10.03 | Difference: -1.41


 97%|█████████▋| 134/138 [04:04<00:07,  1.86s/it]

Layer:5 | Original Loss: 11.44 | Reconstructed Loss: 16.89 | Difference: 5.45
Layer:1 | Original Loss: 11.44 | Reconstructed Loss: 9.88 | Difference: -1.56
Layer:2 | Original Loss: 11.44 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.44 | Reconstructed Loss: 11.21 | Difference: -0.22
Layer:4 | Original Loss: 11.44 | Reconstructed Loss: 10.03 | Difference: -1.41


 98%|█████████▊| 135/138 [04:06<00:05,  1.86s/it]

Layer:5 | Original Loss: 11.44 | Reconstructed Loss: 16.89 | Difference: 5.45
Layer:1 | Original Loss: 11.44 | Reconstructed Loss: 9.88 | Difference: -1.56
Layer:2 | Original Loss: 11.44 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.44 | Reconstructed Loss: 11.22 | Difference: -0.22
Layer:4 | Original Loss: 11.44 | Reconstructed Loss: 10.03 | Difference: -1.41


 99%|█████████▊| 136/138 [04:08<00:03,  1.86s/it]

Layer:5 | Original Loss: 11.44 | Reconstructed Loss: 16.89 | Difference: 5.45
Layer:1 | Original Loss: 11.44 | Reconstructed Loss: 9.88 | Difference: -1.56
Layer:2 | Original Loss: 11.44 | Reconstructed Loss: 10.21 | Difference: -1.23
Layer:3 | Original Loss: 11.44 | Reconstructed Loss: 11.21 | Difference: -0.22
Layer:4 | Original Loss: 11.44 | Reconstructed Loss: 10.03 | Difference: -1.41


100%|██████████| 138/138 [04:10<00:00,  1.82s/it]

Layer:5 | Original Loss: 11.44 | Reconstructed Loss: 16.89 | Difference: 5.45
Layer:1 | Original Loss: 11.41 | Reconstructed Loss: 9.86 | Difference: -1.55
Layer:2 | Original Loss: 11.41 | Reconstructed Loss: 10.19 | Difference: -1.23
Layer:3 | Original Loss: 11.41 | Reconstructed Loss: 11.19 | Difference: -0.22
Layer:4 | Original Loss: 11.41 | Reconstructed Loss: 10.01 | Difference: -1.40
Layer:5 | Original Loss: 11.41 | Reconstructed Loss: 16.85 | Difference: 5.44





TypeError: unsupported operand type(s) for /=: 'list' and 'int'

In [50]:
[o_loss/(i+1) for o_loss in original_loss], [r_loss/(i+1) for r_loss in reconstructed_loss], [l_diff/(i+1) for l_diff in loss_diff]

([tensor(11.4106, device='cuda:0'),
  tensor(11.4106, device='cuda:0'),
  tensor(11.4106, device='cuda:0'),
  tensor(11.4106, device='cuda:0'),
  tensor(11.4106, device='cuda:0')],
 [tensor(9.8566, device='cuda:0'),
  tensor(10.1851, device='cuda:0'),
  tensor(11.1898, device='cuda:0'),
  tensor(10.0089, device='cuda:0'),
  tensor(16.8508, device='cuda:0')],
 [tensor(-1.5540, device='cuda:0'),
  tensor(-1.2255, device='cuda:0'),
  tensor(-0.2208, device='cuda:0'),
  tensor(-1.4017, device='cuda:0'),
  tensor(5.4402, device='cuda:0')])

In [None]:


def train_sparse_model_and_check_CE(model, dataset, device, sparse_weights, autoencoders, cache_names, optimizers, test_datasest=None):
    mse = nn.MSELoss()
    losses = [[] for _ in range(len(sparse_weights))]

    for i, batch in enumerate(tqdm(dataset)):
        batch = batch["input_ids"].to(device)
        with torch.no_grad():
            with TraceDict(model, cache_names) as ret:
                _ = model(batch)
        
        for cache_name_ind in range(len(cache_names) // 2):
            sparse_weight = sparse_weights[cache_name_ind]

            input_cache_name = cache_names[cache_name_ind * 2]
            output_cache_name = cache_names[cache_name_ind * 2 + 1]
            input_autoencoder = autoencoders[cache_name_ind * 2]
            output_autoencoder = autoencoders[cache_name_ind * 2 + 1]
            input_activations = ret[input_cache_name].output
            output_activations = ret[output_cache_name].output

            if isinstance(input_activations, tuple):
                input_activations = input_activations[0]
            if isinstance(output_activations, tuple):
                output_activations = output_activations[0]

            input_internal_activations = rearrange(input_activations, "b s n -> (b s) n")
            output_internal_activations = rearrange(output_activations, "b s n -> (b s) n")

            input_internal_activations = input_autoencoder.encode(input_internal_activations)
            output_internal_activations = output_autoencoder.encode(output_internal_activations)

            x_hat = sparse_weights[cache_name_ind](input_internal_activations)

            # loss = mse(x_hat, output_internal_activations)
            current_alive_features_ind = alive_features_ind[cache_name_ind].nonzero(as_tuple=True)[0]
            indexed_x_hat = x_hat.index_select(dim=1, index=current_alive_features_ind.to(device))
            indexed_output_internal_activations = output_internal_activations.index_select(dim=1, index=current_alive_features_ind.to(device))
            loss = mse(indexed_x_hat, indexed_output_internal_activations)
            l1_loss = torch.norm(sparse_weight.linear.weight, dim=1, p=1).mean()
            # loss += l1_alpha * l1_loss

            loss.backward()
            optimizers[cache_name_ind].step()
            optimizers[cache_name_ind].zero_grad()
            losses[cache_name_ind].append(loss.item())

            if i % 100 == 0:
                print(f"Layers: {cache_name_ind} | MSE Loss: {loss.item()} | L1 Loss: {l1_loss.item()}")
            
        if i > 3000:
            # Stop training after 1000 batches
            break
    return losses

mlp_losses = train_sparse_model(model, dataset, device, mlp_weights, autoencoders, cache_names, mlp_optimizers)
