# Stochastic Parameter Decomposition

## Imports

In [None]:
# import all the torch stuff, as well as the ViT stuff and such
# import plotly.express as px
# import torch
# from jaxtyping import Int, Float
# from typing import List, Optional, Tuple
# from tqdm import tqdm
# from transformer_lens.hook_points import HookPoint
# from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
# import circuitsvis as cv

In [11]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import einops
import typing
from tqdm import tqdm

## Toy Model MLP

In [1]:
# create a SPDHookedTransformer Class
# that extends the hookedTransformer
# which inherits all the previous class stuff
# but also contains the SPD training algo?

# how is everything actually structured? 
"""
Maybe I should start with implementing it on a simple MLP setup. 

Orig network:
- 1 layer MLP 5-2-5 (defined as usual with torch.Sequential presumably)

SPD network:
let's give it 10 subcomponents per layer
then it's just like 10 matmuls? 
maybe i should define it like ant did in the toy models of superposition paper
stick them all into a trenchcoat
this might be easier once I have defined the toy modle

"""

In [55]:
config = {
    "num_layers": 1,
    "pre_embed_size": 100,
    "in_size": 1000,
    "hidden_size": 50,
    "subcomponents_per_layer": 10, 
    "beta_1": 1, 
    "beta_2": 1, 
    "beta_3": 1, 
    "causal_imp_min": 1, 
    "num_mask_samples": 20,
    "importance_mlp_size": 10,
}

In [77]:
class ToyResidMLP(nn.Module):
    def __init__(self, config, device="mps"):
        super().__init__()
        # Initialize Weights for the
        self.num_layers, self.pre_embed_size, self.in_size, self.hidden_size = config["num_layers"], config["pre_embed_size"], config["in_size"], config["hidden_size"]
        self.device = device
        self.W_embed = nn.Parameter(torch.empty((self.pre_embed_size, self.in_size)))
        self.W_unembed = nn.Parameter(torch.empty((self.in_size, self.pre_embed_size)))
        self.W_in = nn.ParameterList([torch.empty((self.in_size, self.hidden_size), device=device) for i in range(self.num_layers)])
        self.W_out = nn.ParameterList([torch.empty((self.hidden_size, self.in_size), device=device) for i in range(self.num_layers)])
        self.b = nn.ParameterList([torch.zeros((self.hidden_size,), device=device) for i in range(self.num_layers)])

        for param in [self.W_embed, self.W_unembed] + list(self.W_in) + list(self.W_out): 
            nn.init.xavier_normal_(param)
        
    def forward(self, x): 
        
        assert x.shape[1] == self.pre_embed_size, f"Input shape {x.shape[0]} does not match model's accepted size {self.pre_embed_size}"
        # embed 
        x_resid = torch.einsum("np,pi->ni", x.clone(), self.W_embed)
        N, D = x_resid.shape

        
        for l in range(self.num_layers):
            hidden = F.relu(torch.einsum("nd,dh -> nh", x_resid, self.W_in[l]) + self.b[l])
            layer_out = torch.einsum("nh,hd -> nd", hidden, self.W_out[l])
            x_resid = x_resid + layer_out
        # am I supposed to have a embed and out?
        x_out = torch.einsum("ni,ip->np", x_resid, self.W_unembed) 
        return x_out


# Generated by LLM bc I am lazy and this is not important to me learning stuff atm
class SparseAutoencoderDataset(Dataset):
    """
    Dataset for learning to reconstruct sparse inputs.
    Each item is (input, target) where target is the input itself (or ReLU of input).
    """
    def __init__(self, in_dim=100, n_samples=10000, sparsity=0.9, device="mps"):
        super().__init__()
        self.in_dim = in_dim
        self.n_samples = n_samples
        self.device = device

        # Pre-generate all samples
        self.inputs = []
        self.targets = []
        
        for _ in range(n_samples):
            # Sparse input: each entry is -1, 0, or 1, with sparsity
            x = np.random.choice([0, -1, 1], size=(in_dim,), p=[1-sparsity, sparsity/2, sparsity/2])
            x = torch.tensor(x, dtype=torch.float32, device=device)
            target = F.relu(x)
            
            self.inputs.append(x)
            self.targets.append(target)
    
    def __len__(self):
        return self.n_samples
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

def train_toy_resid_mlp(
    model,
    dataloader,
    lr=1e-3,
    num_epochs=10,
    device="mps",
    print_every=1
):
    model.train()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for x, y in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        if (epoch+1) % print_every == 0:
            print(f"Epoch {epoch+1}: avg MSE loss = {avg_loss:.6f}")

In [78]:
## LLM-Generated Usage Example

torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":
    # Config
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Dataset and DataLoader
    # Dataset and DataLoader
    dataset = SparseAutoencoderDataset(
        in_dim=100,
        n_samples=10000,
        sparsity=0.9,
        device=device,
    )
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    # Model
    model = ToyResidMLP(config, device=device)
    # Train
    train_toy_resid_mlp(model, dataloader, lr=8e-2, num_epochs=20, device=device)

Epoch 1/20:  13%|████████████▊                                                                                        | 10/79 [00:00<00:00, 98.47it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 1/20:  41%|████████████████████████████████████████▌                                                           | 32/79 [00:00<00:00, 102.83it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 1/20:  67%|███████████████████████████████████████████████████████████████████▊                                 | 53/79 [00:00<00:00, 97.80it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128

Epoch 1/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 99.71it/s]


x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 2/20:   0%|                                                                                                              | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 2/20:  13%|████████████▊                                                                                        | 10/79 [00:00<00:00, 95.81it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 2/20:  25%|█████████████████████████▌                                                                           | 20/79 [00:00<00:00, 95.35it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 2/20:  38%|██████████████████████████████████████▎                                                              | 30/79 [00:00<00:00, 96.46it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 2/20:  51%|███████████████████████████████████████████████████▏                                                 | 40/79 [00:00<00:00, 96.59it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 2/20:  63%|███████████████████████████████████████████████████████████████▉                                     | 50/79 [00:00<00:00, 95.87it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 2/20:  77%|█████████████████████████████████████████████████████████████████████████████▉                       | 61/79 [00:00<00:00, 99.37it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 2/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 99.63it/s]


x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 3/20:   0%|                                                                                                              | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 3/20:  14%|█████████████▉                                                                                      | 11/79 [00:00<00:00, 102.83it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128

Epoch 3/20:  28%|███████████████████████████▊                                                                        | 22/79 [00:00<00:00, 104.99it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 3/20:  42%|█████████████████████████████████████████▊                                                          | 33/79 [00:00<00:00, 102.34it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128

Epoch 3/20:  56%|███████████████████████████████████████████████████████▋                                            | 44/79 [00:00<00:00, 100.36it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 3/20:  70%|██████████████████████████████████████████████████████████████████████▎                              | 55/79 [00:00<00:00, 99.64it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128

Epoch 3/20:  82%|███████████████████████████████████████████████████████████████████████████████████                  | 65/79 [00:00<00:00, 98.32it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 3/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 99.14it/s]


x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 3: avg MSE loss = 59.208182


Epoch 4/20:   0%|                                                                                                              | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 4/20:  13%|████████████▊                                                                                        | 10/79 [00:00<00:00, 99.13it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 4/20:  27%|██████████████████████████▌                                                                         | 21/79 [00:00<00:00, 101.30it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 4/20:  41%|████████████████████████████████████████▉                                                            | 32/79 [00:00<00:00, 95.72it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 4/20:  54%|██████████████████████████████████████████████████████▉                                              | 43/79 [00:00<00:00, 98.73it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 4/20:  68%|█████████████████████████████████████████████████████████████████████                                | 54/79 [00:00<00:00, 99.85it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128

Epoch 4/20:  82%|██████████████████████████████████████████████████████████████████████████████████▎                 | 65/79 [00:00<00:00, 101.92it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 4/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 101.57it/s]


x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 4: avg MSE loss = 29.398272


Epoch 5/20:   0%|                                                                                                              | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 5/20:  14%|█████████████▉                                                                                      | 11/79 [00:00<00:00, 105.49it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 5/20:  28%|███████████████████████████▊                                                                        | 22/79 [00:00<00:00, 105.43it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 5/20:  42%|█████████████████████████████████████████▊                                                          | 33/79 [00:00<00:00, 105.13it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 5/20:  56%|███████████████████████████████████████████████████████▋                                            | 44/79 [00:00<00:00, 105.87it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 5/20:  70%|█████████████████████████████████████████████████████████████████████▌                              | 55/79 [00:00<00:00, 105.49it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 5/20:  84%|███████████████████████████████████████████████████████████████████████████████████▌                | 66/79 [00:00<00:00, 106.14it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 5/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 105.85it/s]


x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 5: avg MSE loss = 16.977402


Epoch 6/20:   0%|                                                                                                              | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 6/20:  14%|█████████████▉                                                                                      | 11/79 [00:00<00:00, 106.26it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 6/20:  28%|███████████████████████████▊                                                                        | 22/79 [00:00<00:00, 105.98it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 6/20:  42%|█████████████████████████████████████████▊                                                          | 33/79 [00:00<00:00, 106.18it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 6/20:  56%|███████████████████████████████████████████████████████▋                                            | 44/79 [00:00<00:00, 106.01it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 6/20:  70%|█████████████████████████████████████████████████████████████████████▌                              | 55/79 [00:00<00:00, 106.23it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 6/20:  84%|███████████████████████████████████████████████████████████████████████████████████▌                | 66/79 [00:00<00:00, 106.14it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 6/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 106.22it/s]


x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 6: avg MSE loss = 10.628971


Epoch 7/20:   0%|                                                                                                              | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 7/20:  11%|███████████▌                                                                                          | 9/79 [00:00<00:00, 72.86it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 7/20:  24%|████████████████████████▎                                                                            | 19/79 [00:00<00:00, 84.36it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 7/20:  38%|██████████████████████████████████████▎                                                              | 30/79 [00:00<00:00, 93.99it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 7/20:  52%|████████████████████████████████████████████████████▍                                                | 41/79 [00:00<00:00, 98.48it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 7/20:  66%|█████████████████████████████████████████████████████████████████▊                                  | 52/79 [00:00<00:00, 101.98it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128

Epoch 7/20:  80%|███████████████████████████████████████████████████████████████████████████████▋                    | 63/79 [00:00<00:00, 103.11it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 7/20:  94%|█████████████████████████████████████████████████████████████████████████████████████████████▋      | 74/79 [00:00<00:00, 104.28it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 7/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 99.65it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])





Epoch 7: avg MSE loss = 7.007546


Epoch 8/20:   0%|                                                                                                              | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 8/20:  14%|█████████████▉                                                                                      | 11/79 [00:00<00:00, 104.94it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 8/20:  28%|███████████████████████████▊                                                                        | 22/79 [00:00<00:00, 105.50it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 8/20:  42%|█████████████████████████████████████████▊                                                          | 33/79 [00:00<00:00, 105.79it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 8/20:  56%|███████████████████████████████████████████████████████▋                                            | 44/79 [00:00<00:00, 106.22it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 8/20:  70%|█████████████████████████████████████████████████████████████████████▌                              | 55/79 [00:00<00:00, 106.38it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 8/20:  84%|███████████████████████████████████████████████████████████████████████████████████▌                | 66/79 [00:00<00:00, 106.24it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 8/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 105.02it/s]


x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 8: avg MSE loss = 4.821876


Epoch 9/20:   0%|                                                                                                              | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 9/20:  13%|████████████▊                                                                                        | 10/79 [00:00<00:00, 95.17it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 9/20:  27%|██████████████████████████▊                                                                          | 21/79 [00:00<00:00, 98.05it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 9/20:  39%|███████████████████████████████████████▋                                                             | 31/79 [00:00<00:00, 96.40it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 9/20:  52%|████████████████████████████████████████████████████▍                                                | 41/79 [00:00<00:00, 96.26it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 9/20:  65%|█████████████████████████████████████████████████████████████████▏                                   | 51/79 [00:00<00:00, 96.41it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 9/20:  77%|█████████████████████████████████████████████████████████████████████████████▉                       | 61/79 [00:00<00:00, 96.87it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 9/20:  91%|████████████████████████████████████████████████████████████████████████████████████████████         | 72/79 [00:00<00:00, 99.16it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 9/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 98.57it/s]


x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 9: avg MSE loss = 3.420362


Epoch 10/20:   0%|                                                                                                             | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 10/20:  14%|█████████████▊                                                                                     | 11/79 [00:00<00:00, 103.91it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 10/20:  28%|███████████████████████████▊                                                                        | 22/79 [00:00<00:00, 99.34it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 10/20:  42%|█████████████████████████████████████████▎                                                         | 33/79 [00:00<00:00, 101.28it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 10/20:  56%|███████████████████████████████████████████████████████▏                                           | 44/79 [00:00<00:00, 102.27it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 10/20:  70%|████████████████████████████████████████████████████████████████████▉                              | 55/79 [00:00<00:00, 102.34it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 10/20:  84%|██████████████████████████████████████████████████████████████████████████████████▋                | 66/79 [00:00<00:00, 102.36it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 10/20:  97%|████████████████████████████████████████████████████████████████████████████████████████████████▍  | 77/79 [00:00<00:00, 102.43it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])


Epoch 10/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 102.19it/s]


x out shape torch.Size([16, 100])
Epoch 10: avg MSE loss = 2.491114


Epoch 11/20:   0%|                                                                                                             | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 11/20:  14%|█████████████▊                                                                                     | 11/79 [00:00<00:00, 103.97it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 11/20:  28%|███████████████████████████▌                                                                       | 22/79 [00:00<00:00, 104.14it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 11/20:  42%|█████████████████████████████████████████▎                                                         | 33/79 [00:00<00:00, 104.63it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 11/20:  56%|███████████████████████████████████████████████████████▏                                           | 44/79 [00:00<00:00, 104.99it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 11/20:  70%|████████████████████████████████████████████████████████████████████▉                              | 55/79 [00:00<00:00, 105.08it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 11/20:  84%|██████████████████████████████████████████████████████████████████████████████████▋                | 66/79 [00:00<00:00, 105.30it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 11/20:  97%|████████████████████████████████████████████████████████████████████████████████████████████████▍  | 77/79 [00:00<00:00, 104.96it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])


Epoch 11/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 105.03it/s]


Epoch 11: avg MSE loss = 1.860358


Epoch 12/20:   0%|                                                                                                             | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 12/20:  14%|█████████████▊                                                                                     | 11/79 [00:00<00:00, 103.63it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 12/20:  28%|███████████████████████████▌                                                                       | 22/79 [00:00<00:00, 104.02it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 12/20:  42%|█████████████████████████████████████████▎                                                         | 33/79 [00:00<00:00, 104.95it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 12/20:  56%|███████████████████████████████████████████████████████▏                                           | 44/79 [00:00<00:00, 105.13it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 12/20:  70%|████████████████████████████████████████████████████████████████████▉                              | 55/79 [00:00<00:00, 105.45it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 12/20:  84%|██████████████████████████████████████████████████████████████████████████████████▋                | 66/79 [00:00<00:00, 105.53it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 12/20:  97%|████████████████████████████████████████████████████████████████████████████████████████████████▍  | 77/79 [00:00<00:00, 105.52it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 12/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 105.29it/s]


x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 12: avg MSE loss = 1.423939


Epoch 13/20:   0%|                                                                                                             | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 13/20:  14%|█████████████▊                                                                                     | 11/79 [00:00<00:00, 102.55it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 13/20:  28%|███████████████████████████▌                                                                       | 22/79 [00:00<00:00, 101.63it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 13/20:  42%|█████████████████████████████████████████▎                                                         | 33/79 [00:00<00:00, 101.88it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 13/20:  56%|███████████████████████████████████████████████████████▏                                           | 44/79 [00:00<00:00, 103.10it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 13/20:  70%|████████████████████████████████████████████████████████████████████▉                              | 55/79 [00:00<00:00, 103.92it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 13/20:  84%|██████████████████████████████████████████████████████████████████████████████████▋                | 66/79 [00:00<00:00, 104.42it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 13/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 103.21it/s]


x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 13: avg MSE loss = 1.114994


Epoch 14/20:   0%|                                                                                                             | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 14/20:  13%|████████████▋                                                                                       | 10/79 [00:00<00:00, 98.64it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 14/20:  27%|██████████████████████████▎                                                                        | 21/79 [00:00<00:00, 102.07it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 14/20:  41%|████████████████████████████████████████                                                           | 32/79 [00:00<00:00, 103.28it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 14/20:  54%|█████████████████████████████████████████████████████▉                                             | 43/79 [00:00<00:00, 104.22it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 14/20:  68%|███████████████████████████████████████████████████████████████████▋                               | 54/79 [00:00<00:00, 104.82it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 14/20:  82%|█████████████████████████████████████████████████████████████████████████████████▍                 | 65/79 [00:00<00:00, 104.91it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 14/20:  96%|███████████████████████████████████████████████████████████████████████████████████████████████▏   | 76/79 [00:00<00:00, 104.68it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])


Epoch 14/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 104.23it/s]


x out shape torch.Size([16, 100])
Epoch 14: avg MSE loss = 0.894826


Epoch 15/20:   0%|                                                                                                             | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 15/20:  14%|█████████████▊                                                                                     | 11/79 [00:00<00:00, 103.54it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 15/20:  28%|███████████████████████████▊                                                                        | 22/79 [00:00<00:00, 96.41it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 15/20:  41%|████████████████████████████████████████▌                                                           | 32/79 [00:00<00:00, 96.48it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 15/20:  54%|██████████████████████████████████████████████████████▍                                             | 43/79 [00:00<00:00, 97.63it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 15/20:  67%|███████████████████████████████████████████████████████████████████                                 | 53/79 [00:00<00:00, 97.16it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 15/20:  81%|█████████████████████████████████████████████████████████████████████████████████                   | 64/79 [00:00<00:00, 98.76it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 15/20:  94%|█████████████████████████████████████████████████████████████████████████████████████████████▋      | 74/79 [00:00<00:00, 84.07it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 15/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 92.07it/s]


x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 15: avg MSE loss = 0.736765


Epoch 16/20:   0%|                                                                                                             | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 16/20:  14%|█████████████▊                                                                                     | 11/79 [00:00<00:00, 101.78it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128

Epoch 16/20:  28%|███████████████████████████▌                                                                       | 22/79 [00:00<00:00, 100.81it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 16/20:  42%|█████████████████████████████████████████▊                                                          | 33/79 [00:00<00:00, 99.54it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 16/20:  56%|███████████████████████████████████████████████████████▏                                           | 44/79 [00:00<00:00, 100.31it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 16/20:  70%|█████████████████████████████████████████████████████████████████████▌                              | 55/79 [00:00<00:00, 98.91it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 16/20:  84%|███████████████████████████████████████████████████████████████████████████████████▌                | 66/79 [00:00<00:00, 99.69it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 16/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 100.48it/s]


x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 16: avg MSE loss = 0.621993


Epoch 17/20:   0%|                                                                                                             | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 17/20:  14%|█████████████▊                                                                                     | 11/79 [00:00<00:00, 101.48it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 17/20:  28%|███████████████████████████▌                                                                       | 22/79 [00:00<00:00, 101.52it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 17/20:  42%|█████████████████████████████████████████▎                                                         | 33/79 [00:00<00:00, 102.85it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 17/20:  56%|███████████████████████████████████████████████████████▏                                           | 44/79 [00:00<00:00, 103.51it/s]

x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 17/20:  70%|████████████████████████████████████████████████████████████████████▉                              | 55/79 [00:00<00:00, 103.89it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128

Epoch 17/20:  84%|██████████████████████████████████████████████████████████████████████████████████▋                | 66/79 [00:00<00:00, 104.27it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])


Epoch 17/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 103.86it/s]


x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([16, 100])
x out shape torch.Size([16, 100])
Epoch 17: avg MSE loss = 0.539775


Epoch 18/20:   0%|                                                                                                             | 0/79 [00:00<?, ?it/s]

x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])
x input shape torch.Size([128, 100])
x out shape torch.Size([128, 100])


Epoch 18/20:   9%|████████▉                                                                                            | 7/79 [00:00<00:00, 89.57it/s]


KeyboardInterrupt: 

## SPD Model and Train Function

In [8]:
# generated by LLM

class HardSigmoid(nn.Module):
    """
    Implements the hard sigmoid activation function as described in the paper:
        σ_H(x) = 0 if x <= 0
               = x if 0 < x < 1
               = 1 if x >= 1
    This is equivalent to: torch.clamp(x, min=0.0, max=1.0)
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Clamp values between 0 and 1
        return torch.clamp(x, min=0.0, max=1.0)

tensor([0.0000, 0.0000, 0.0000, 0.3000, 0.7000, 1.0000, 1.0000])


In [None]:
class SPDModelMLP(nn.Module): 
    # fix the betas stuff below that's basically a type hint
    def __init__(self, target_model, config, device="mps"): 
        self.device = device
        self.target_model = target_model
        self.num_layers, self.in_size, self.hidden_size, self.imp_hidden_size = config["num_layers"], config["in_size"], config["hidden_size"], config["importance_mlp_size"]
        assert self.device == target_model.device, "Models not on same device"
        self.C = config["subcomponents_per_layer"]
        self.hypers = dict(list(config.items())[4:]) # sets the "hypers" to contain all the hyperparameters for the model
        
        # Subcomponent vectors, each of shape C by in_size; to be used
        # with outer product to create our low-rank subcomponent matrices
        self.V_in = nn.ParameterList([torch.empty((self.C, in_size,), device=device) for i in range(num_layers)])
        self.U_in = nn.ParameterList([torch.empty((self.C, hidden_size,), device=device) for i in range(num_layers)]) 
        self.V_out = nn.ParameterList([torch.empty((self.C, hidden_size,), device=device) for i in range(num_layers)])
        self.U_out = nn.ParameterList([torch.empty((self.C, in_size,), device=device) for i in range(num_layers)])
        
        # idk what you do with the biases lol
        self.b = nn.ParameterList([torch.zeros((hidden_size,), device=device) for i in range(num_layers)])
        
        # imp_W_in and out etc are the weights for the importance predictor
        # should be C networks per layer, mapping 1 -> C -> 1
        self.imp_W_gate_in = nn.ParameterList([torch.empty(C, 1, self.imp_hidden_size) for i in range(num_layers)])
        self.imp_W_gate_out = nn.ParameterList([torch.empty(C, self.imp_hidden_size, 1) for i in range(num_layers)])
        self.imp_b_in = nn.ParameterList([torch.empty(C, self.imp_hidden_size) for i in range(num_layers)])
        self.imp_b_out = nn.ParameterList([torch.empty(C, 1) for i in range(num_layers)])
        
        
        # self.pred_weights = torch.empty((num_layers, self.C), device=device)

        for param in self.V_in + self.U_in + self.V_out + self.U_out + self.b: 
            # using xavier anyway -- note that the variance etc is
            # changed because we take the outer product in the 
            # forward pass
            nn.init.xavier_normal_(param)

        
    def forward(self, x, return_activations=False, return_weight_matrices=False, masks=None): # Regular run. Unclear whether I should have masking when I do a regular forward pass.
        v_activations = []
        weight_matrices = []
        x_resid = x.clone()
        layerwise_resids = []
        
        for l in self.num_layers:

            if masks is not None: 
                # may have dimension issues in these einsums :(
                ### FULL MASK WEIGHTS
                v_activ_layer_in = torch.einsum("ni,ci->nc", x_resid, self.V_in[l]) * masks[l]["in"]
                layer_hidden = torch.einsum("nc,ch->nh", v_activ_layer_in, self.U_in[l])
                v_activ_layer_out = torch.einsum("nh,ch->nc", layer_in, self.V_out[l]) * masks[l]["out"]
                layer_out = torch.einsum("nc,ci->ni", v_activ_layer_in, self.U_out[l])
                
                x_resid_layerwise = x.clone()
                for l_2 in self.num_layers: # oh god this is so awful. literally sobbing rn. i can't believe I'm writing it like this
                    if l_2 == l: 
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise, self.V_in[l]) * masks[l]["in"]
                        layer_hidden_l = torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l])
                        v_activ_layer_l = torch.einsum("nh,ch->nc", layer_in_l, self.V_out[l]) * masks[l]["out"]
                        layer_out_l = torch.einsum("nc,ci->ni", v_activ_layer_in_l, self.U_out[l])
                    else: # THE SAME EXACT THING EXCEPT WITHOUT THE MASKS >_< 
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise, self.V_in[l])
                        layer_hidden_l = torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l])
                        v_activ_layer_l = torch.einsum("nh,ch->nc", layer_in_l, self.V_out[l])
                        layer_out_l = torch.einsum("nc,ci->ni", v_activ_layer_in_l, self.U_out[l])
                    x_resid_layerwise = x_resid_layerwise + layer_out_l
                    
                layerwise_resids.append(x_resid_layerwise)
                                
            else: 
                # Use outer product to create weights for the layer, then sum all the subcomponents
                W_in = torch.einsum("ci,ch-> cih", self.V_in[l], self.U_in[l]).sum(dim=0) # shape i h
                W_out = torch.einsum("ch,ci-> chi", self.V_out[l], self.U_out[l]).sum(dim=0) # shape h i
                if return_weight_matrices == True:
                    weight_matrices.append({"in": W_in, "out", W_out})
                # COMPUTE 
                layer_in = torch.einsum("ni,ih -> nh", x_resid, W_in) + self.b[l]
                layer_out = torch.einsum("nh,hi->ni", F.relu(layer_in), W_out)
            
            if return_activations == True: 
                # calculate activations
                v_activ_layer_in = torch.einsum("ni,ci->nc", x_resid, self.V_in[l]).unsqueeze(-1)
                v_activ_layer_out = torch.einsum("nh,ch->nc", layer_in, self.V_out[l]).unsqueeze(-1)
                # both are now shape n,c,1
                v_activations.append({"in": v_activ_layer_in, "out": v_activ_layer_out})
                
                # code to make sure I'm not doing some goofy shit with the activations at the beginning
                check_layer_in_uv = torch.einsum("nc,ch->nh", v_activ_layer_in, self.U_in[l])
                assert layer_in == check_layer_in_uv, "Subcomponent activations not calculated correctly"
                # would in theory have one for the second too but I'm being lazy
            
            x_resid = x_resid + layer_out

        if (return_activations == True) and (return_weight_matrices == True): # these will be used together. should probably merge them into 1 but won't right now
            return x_resid, v_activations, weight_matrices 
        else if masks is not None: 
            return x_resid, layerwise_resids
        else: 
            return x_resid
    


In [None]:
def generate_batch(config): 
    # return shape N, config.in_size 
    pass


def train_SPD(spd_model): # could also implement this by passing in the original model
    x = generate_batch() # TODO. possibly put in the model? idk if needed.
    N = x.shape[0]     # x is shape N by in_size
    C = model.C
    target_model = spd_model.target_model
    hard_sigmoid = HardSigmoid()
    
    # COMPUTE TARGET OUTPUT
    with torch.no_grad(): 
        target_out = spd_model.target_model(x)

    # FAITHFULNESS LOSS
    
    spd_output, spd_activations spd_weights = model(x, return_activations = True, return_weight_matrices = True)
    squared_error = 0
    for l in range(num_layers):
        in_diff = target_model.W_in[l] - spd_weights[l]["in"]
        out_diff = target_model.W_out[l] - spd_weights[l]["out"]
        
        # torch.linalg.matrix_norm defaults to the frobenius norm
        # this takes the frobenius norm of the diff and then squares it
        squared_error_layer = torch.linalg.matrix_norm(in_diff) ** 2 + torch.linalg.matrix_norm(out_diff) ** 2 
        squared_error = squared_error + squared_error_layer
        
    mean_squared_error = squared_error/num_layers
    l_faithfulness = mean_squared_error
    

    ## IMPORTANCE-MINIMALITY LOSS
    
    pred_importances = []
    l_importance_minimality = 0
    
    for l in range(num_layers):
        # both activations are n by c containing dot product so we already have hard_sigmoid
        # spd_activations[l][inout] is shape n,c,1 (nco)
        # imp_W_in is c by 1 by imp_size (cos)
        # want to map to ncs then back to nco
        # imp_b_in is shape (C, s) so should broadcast to ncs nicely
        
        # TODO: DEFINE hard_sigmoid AS A TORCH MODULE SO IT CAN CALCULATE THE DERIVATIVE 
        # in theory should write this as a bunch of models stored in the main model, but that's not how i did it and i've already 
        # written this, shrug
        components_imp_pred_hidden_in = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["in"], spd_model.imp_W_gate_in) + spd_model.imp_b_in)
        components_pred_layer_in = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_hidden, spd_model.imp_W_gate_out) + spd_model.imp_b_out)

        #same thing for the out matrix in layer l
        components_imp_pred_hidden_out = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["out"], spd_model.imp_W_gate_in) + spd_model.imp_b_in)
        components_pred_layer_out = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_hidden, spd_model.imp_W_gate_out) + spd_model.imp_b_out)
        
        pred_importances.append({"in": components_pred_layer_in, "out": components_pred_layer_out)
        l_importance_minimality = l_importance_minimality + (components_pred_layer_in ** spd_model.hypers["importance_mlp_size"]) + (components_pred_layer_out ** spd_model.hypers["importance_mlp_size"])
    
    l_importance_minimality /= N # divide by N, the batch size, to avg across the batch (notated as B in the paper)



    ## STOCHASTIC RECONSTRUCTION LOSS 

    l_stochastic_recon = 0
    l_stochastic_recon_layerwise = 0
    R = torch.rand(model.hypers["num_mask_samples"], N, model.num_layers, C)
    layer_masks = []

    for s in range(model.hypers["num_mask_samples"]):
        # Running this with a for loop. This is slow; I'd ideally just run 
        # something like M = G[None, :, :] + (1 - G[None, :, :]) * R
        # But I think there's a clarity tradeoff here so I'm just going to do this for now
        for l in range(num_layers):
            # ugh, this is pretty bad. Ideally I'd go back and refactor so that I don't 
            # have separate in and out weights, but this is what I have. going to keep it
            # like this for now, but will refactor once I have something that works.
            # components pred in is shape N, C, 1, squeeze to N, C. there are L of them
            # R is N, L, C, sample just l to get N, C
            # I want masks for each component on each layer on each datapoint, so
            # masks should be one C vector for each layer, per-datapoint. so NLC
            layer_mask_in = pred_importances[l]["in"].squeeze() + (torch.ones_like(pred_importances[l]["in"]) - pred_importances[l]["in"])*R[s,:,l,:]
            layer_mask_out = pred_importances[l]["out"].squeeze() + (torch.ones_like(pred_importances[l]["out"]) - pred_importances[l]["out"]) * R [s,:,l,:]
            # these are both shape N, C 
            layer_masks.append({"in": layer_mask_in, "out": layer_mask_out})
        masked_out, layerwise_masked_outs = model(x, masks=layer_masks)
        l_stochastic_recon = l_stochastic_recon + torch.linalg.matrix_norm(target_out - masked_out) # uses matrix norm of difference
        for i in range(len(layerwise_masked_outs)-1):
            l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[i])

    l_stochastic_recon /= model.hypers["num_mask_samples"]
    l_stochastic_recon_layerwise /= (model.hypers["num_mask_samples"] * model.num_layers)

    loss = l_faithfulness + model.hypers["beta1"] * l_stochastic_recon + model.hypers["beta2"] * l_stochastic_recon_layerwise + model.hypers["beta3"] * l_importance_minimality

    """
    TODO: 
    - generate batches
        - i want a dataloader object that i can iterate through
        - my first test will be the 1-layer 50-hidden 100-in dim
    - start testing
    - wrap functions with tqdm
    """

    