In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import tqdm
from math import log2
from typing import List, Tuple, Dict
import numpy as np
import pickle

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
def get_cifar_dataloader(
    train_batch_size: int = 256,
    test_batch_size: int = 2048,
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
) -> Tuple[DataLoader, DataLoader]:
    """
    Build image-only CIFAR-10 DataLoaders (3-channel RGB) for training and testing.

    Each split is downloaded to ``./data`` if needed, passed through
    ``ToTensor`` and stacked into one (N, C, H, W) tensor; the labels are
    dropped. For a CUDA ``device`` the stacked tensors are moved to that device
    before being wrapped in a ``TensorDataset``.

    Args:
        train_batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the unshuffled test loader.
        device: Target device ('cuda' or 'cpu').

    Returns:
        Tuple of (train_loader, test_loader). Batches hold a single tensor
        (the training loop unpacks them as ``x, = batch``).
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor()  # Converts to tensor and normalizes to [0,1]
    ])

    def _load_split(is_train: bool) -> torch.Tensor:
        # Materialise one split as a single image tensor, discarding labels.
        split = datasets.CIFAR10(root='./data', train=is_train, transform=to_tensor, download=True)
        images = torch.stack([image for image, _ in split], dim=0)
        return images.to(device) if device.type == 'cuda' else images

    train_images = _load_split(True)
    test_images = _load_split(False)

    # Worker processes and pinned memory are requested only on the non-CUDA path.
    if device.type.startswith("cuda"):
        loader_options = {}
    else:
        loader_options = {"num_workers": 8, "pin_memory": True}

    train_loader = DataLoader(
        TensorDataset(train_images), batch_size=train_batch_size, shuffle=True, **loader_options
    )
    test_loader = DataLoader(
        TensorDataset(test_images), batch_size=test_batch_size, shuffle=False, **loader_options
    )
    return train_loader, test_loader

In [3]:
def get_fashion_mnist_dataloader(
    train_batch_size: int = 256,
    test_batch_size: int = 2048,
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
) -> Tuple[DataLoader, DataLoader]:
    """
    Build image-only Fashion-MNIST DataLoaders (1-channel grayscale) for training and testing.

    Each split is downloaded to ``./data`` if needed, passed through
    ``ToTensor`` and stacked into one (N, C, H, W) tensor; the labels are
    dropped. For a CUDA ``device`` the stacked tensors are moved to that device
    before being wrapped in a ``TensorDataset``.

    Args:
        train_batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the unshuffled test loader.
        device: Target device ('cuda' or 'cpu').

    Returns:
        Tuple of (train_loader, test_loader). Batches hold a single tensor
        (the training loop unpacks them as ``x, = batch``).
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor()  # Converts to tensor and normalizes to [0,1]
    ])

    def _load_split(is_train: bool) -> torch.Tensor:
        # Materialise one split as a single image tensor, discarding labels.
        split = datasets.FashionMNIST(root='./data', train=is_train, transform=to_tensor, download=True)
        images = torch.stack([image for image, _ in split], dim=0)
        return images.to(device) if device.type == 'cuda' else images

    train_images = _load_split(True)
    test_images = _load_split(False)

    # Worker processes and pinned memory are requested only on the non-CUDA path.
    if device.type.startswith("cuda"):
        loader_options = {}
    else:
        loader_options = {"num_workers": 8, "pin_memory": True}

    train_loader = DataLoader(
        TensorDataset(train_images), batch_size=train_batch_size, shuffle=True, **loader_options
    )
    test_loader = DataLoader(
        TensorDataset(test_images), batch_size=test_batch_size, shuffle=False, **loader_options
    )
    return train_loader, test_loader

In [4]:
class VAE(nn.Module):
    """Convolutional (variational) autoencoder for square images.

    The encoder is a stack of Conv2d -> BatchNorm2d -> ReLU blocks (one per
    entry of ``filter_sizes``), followed by linear heads for ``mu`` and
    ``logvar``. The decoder mirrors the encoder with ConvTranspose2d layers and
    ends in a Sigmoid. With ``is_variational=False`` the latent code is simply
    ``mu`` (plain autoencoder).

    Args:
        input_channels: Number of channels of the input image.
        input_size: Height/width N of the (N x N) input image.
        filter_sizes: Output channels of each encoder conv block (copied on
            construction, so neither the default nor the caller's list is aliased).
        latent_dim: Size of the latent vector.
        kernel_size: Kernel size shared by all (transposed) convolutions.
        stride: Stride shared by all (transposed) convolutions.
        padding: Padding shared by all (transposed) convolutions.
        is_variational: If True, sample the latent with the reparameterization
            trick; if False, use ``mu`` directly.

    Raises:
        ValueError: If ``filter_sizes`` is empty or the conv stack would shrink
            the feature map below 1 x 1.
    """

    def __init__(self, input_channels: int = 3, input_size: int = 32,
                 filter_sizes: List[int] = [32, 64, 128], latent_dim: int = 64, kernel_size: int = 3,
                 stride: int = 2, padding: int = 1, is_variational: bool = True):
        super().__init__()
        if len(filter_sizes) == 0:
            raise ValueError("filter_sizes must contain at least one entry")

        self.latent_dim = latent_dim
        self.input_channels = input_channels
        self.input_size = input_size
        # Copy so that the shared default list (and the caller's list) can never
        # be mutated through this instance.
        self.filter_sizes = list(filter_sizes)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.is_variational = is_variational

        # Order matters: the latent layers need flat_dim and the decoder needs
        # data_sizes, both produced by _calculate_conv_output_shape.
        self._build_encoder()
        self._calculate_conv_output_shape()
        self._build_latent_layers()
        self._build_decoder()

    def _build_encoder(self) -> None:
        """Create ``self.encoder``: one Conv2d/BatchNorm2d/ReLU block per filter size."""
        encoder_layers = []
        in_channels = self.input_channels

        for fs in self.filter_sizes:
            encoder_layers.extend([
                nn.Conv2d(in_channels, fs, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding),
                nn.BatchNorm2d(fs),
                nn.ReLU()
            ])
            in_channels = fs

        self.encoder = nn.Sequential(*encoder_layers)

    def _calculate_conv_output_shape(self) -> None:
        """Compute the spatial size after every encoder convolution.

        Uses the instance's ``input_size``, ``kernel_size``, ``stride`` and
        ``padding`` and sets:

        * ``self.data_sizes``: spatial sizes, starting with the input size and
          followed by the size after each conv layer;
        * ``self.flat_dim``: number of features of the flattened encoder output;
        * ``self.conv_out_shape``: encoder output shape as (C, H, W).

        Raises:
            ValueError: If any layer would produce a feature map smaller than 1 x 1.
        """
        current_size = int(self.input_size)
        self.data_sizes = [current_size]
        for layer_idx, _ in enumerate(self.filter_sizes, start=1):
            # Standard conv arithmetic: floor((N + 2P - K) / S) + 1, in integers.
            current_size = int((current_size + 2 * self.padding - self.kernel_size) // self.stride + 1)
            if current_size < 1:
                raise ValueError(
                    f"Conv layer {layer_idx} would produce a {current_size}x{current_size} feature map "
                    f"(input_size={self.input_size}, kernel_size={self.kernel_size}, "
                    f"stride={self.stride}, padding={self.padding})"
                )
            self.data_sizes.append(current_size)

        self.flat_dim = self.filter_sizes[-1] * current_size * current_size
        self.conv_out_shape = torch.Size([int(self.filter_sizes[-1]), current_size, current_size])  # (C, H, W)

    def _build_latent_layers(self) -> None:
        """Create the linear heads for mu / logvar and the latent-to-feature-map projection."""
        self.fc_mu = nn.Linear(self.flat_dim, self.latent_dim)
        self.fc_logvar = nn.Linear(self.flat_dim, self.latent_dim)
        self.fc_decode = nn.Linear(self.latent_dim, self.flat_dim)

    def _build_decoder(self) -> None:
        """Create ``self.decoder``: transposed convs that retrace ``self.data_sizes`` in reverse."""
        decoder_layers = []
        reversed_filters = list(reversed(self.filter_sizes))
        sizes = self.data_sizes[::-1]  # spatial sizes from the bottleneck back to the input
        in_channels = reversed_filters[0]

        for i, fs in enumerate(reversed_filters[1:] + [self.input_channels]):
            # Size a transposed conv yields without output_padding; the
            # difference to the size the encoder saw at this depth is what
            # output_padding has to add back.
            unpadded_out = (sizes[i] - 1) * self.stride - 2 * self.padding + self.kernel_size
            decoder_layers.append(
                nn.ConvTranspose2d(
                    in_channels,
                    fs,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                    output_padding=sizes[i + 1] - unpadded_out
                )
            )
            if i < len(reversed_filters) - 1:  # No BN/ReLU on last layer
                decoder_layers.append(nn.BatchNorm2d(fs))
                decoder_layers.append(nn.ReLU())
            in_channels = fs

        decoder_layers.append(nn.Sigmoid())  # Final activation
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` and return ``(mu, logvar)``.

        Side effect: the flattened encoder output is kept in
        ``self.encoder_exit`` (needed for mutual-information computations).
        """
        features = self.encoder(x)
        x_flat = features.flatten(1)
        self.encoder_exit = x_flat  # needed for computing MI
        return self.fc_mu(x_flat), self.fc_logvar(x_flat)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Return the latent code: ``mu + eps * std`` if variational, else ``mu``.

        In the non-variational case no noise is drawn, so the global RNG state
        is left untouched.
        """
        if not self.is_variational:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent codes ``z`` into images.

        Side effect: the output of ``fc_decode`` (before reshaping) is kept in
        ``self.decoder_input`` (needed for mutual-information computations).
        """
        x = self.fc_decode(z)
        self.decoder_input = x  # needed for computing MI
        x = x.view(z.size(0), *self.conv_out_shape)
        return self.decoder(x)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full autoencoder and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def vae_loss(recon_x: torch.Tensor, x: torch.Tensor,
             mu: torch.Tensor, logvar: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Per-sample VAE loss: summed-MSE reconstruction term plus KL divergence.

    Both terms are summed over the whole batch and then divided by the batch
    size ``x.size(0)``.

    Returns:
        Tuple ``(total, reconstruction, kl)`` of scalar tensors, where
        ``total = reconstruction + kl``.
    """
    n_samples = x.size(0)
    reconstruction = F.mse_loss(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    total = (reconstruction + kl_divergence) / n_samples
    return total, reconstruction / n_samples, kl_divergence / n_samples

In [5]:
def calculate_model_stats(vae_model: VAE) -> dict:
    """Collect architecture statistics of a VAE into a flat dict.

    The dict covers the basic configuration, parameter counts, channel
    statistics, the input-to-bottleneck compression ratio and the encoder's
    receptive field together with the conv hyper-parameters.
    """
    def _n_params(module) -> int:
        # Total number of scalar parameters held by ``module``.
        return sum(p.numel() for p in module.parameters())

    filters = vae_model.filter_sizes
    side = vae_model.input_size

    # 1. Basic model info
    stats = {
        "latent_dim": vae_model.latent_dim,
        "input_shape": [vae_model.input_channels, side, side],
        "filter_sizes": filters,
        "n_conv_blocks": len(filters),
    }

    # 2. Encoder/decoder parameter counts
    n_encoder = _n_params(vae_model.encoder)
    n_decoder = _n_params(vae_model.decoder)
    stats["total_params"] = _n_params(vae_model)
    stats["encoder_params"] = n_encoder
    stats["decoder_params"] = n_decoder
    stats["encoder_decoder_ratio"] = n_encoder / n_decoder

    # 3. Channel statistics
    widest, narrowest = max(filters), min(filters)
    stats["max_channels"] = widest
    stats["min_channels"] = narrowest
    # TODO (original author's note): if this works, rewrite it properly later.
    stats["channel_growth_factor"] = widest / narrowest if len(filters) > 1 else 1.0

    # 4. Compression ratio (spatial + channel)
    out_shape = vae_model.conv_out_shape  # (C, H, W)
    n_input_values = side * side * vae_model.input_channels
    n_bottleneck_values = out_shape[0] * out_shape[1] * out_shape[2]
    stats["compression_ratio"] = n_input_values / n_bottleneck_values
    stats["encoder_output_shape"] = list(out_shape)

    # 5. Receptive field of the encoder and the conv hyper-parameters
    stats["encoder_receptive_field"] = calculate_receptive_field(vae_model.encoder)
    stats["kernel_size"] = vae_model.kernel_size
    stats["stride"] = vae_model.stride
    stats["padding"] = vae_model.padding

    return stats


def calculate_receptive_field(encoder: nn.Sequential) -> int:
    """Calculate the receptive field of the encoder's last conv layer.

    Only ``nn.Conv2d`` layers contribute; all other layers are skipped. Each
    conv widens the field by ``(kernel - 1) * dilation`` input steps of the
    current layer, scaled by the product of the strides of all previous convs.
    Only the first (height) entry of kernel size / stride / dilation is used,
    i.e. square kernels are assumed.

    Args:
        encoder: Sequence of layers to walk through in order.

    Returns:
        Receptive field size in input pixels (1 if there are no conv layers).
    """
    receptive_field = 1
    jump = 1  # distance, in input pixels, between neighbouring units of the current layer
    for layer in encoder:
        if not isinstance(layer, nn.Conv2d):
            continue
        receptive_field += (layer.kernel_size[0] - 1) * layer.dilation[0] * jump
        jump *= layer.stride[0]
    return receptive_field

In [6]:
import numpy as np
import numpy.linalg as la
from numpy import log
from scipy.special import digamma
from sklearn.neighbors import BallTree, KDTree


# Continuous Estimators
def entropy(x, k=3, base=2):
    """ The classic K-L k-nearest neighbor continuous entropy estimator.

        x should be a list of vectors, e.g. x = [[1.3], [3.7], [5.1], [2.4]]
        if x is a one-dimensional scalar and we have four samples.
        k is the neighbour rank used for the distances; the result is
        expressed in units of log ``base``.
    """
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"
    samples = np.asarray(x)
    n_elements, n_features = samples.shape
    samples = add_noise(samples)
    # Distance from every sample to its k-th nearest neighbour.
    kth_distances = query_neighbors(build_tree(samples), samples, k)
    offset = digamma(n_elements) - digamma(k) + n_features * log(2)
    return (offset + n_features * np.log(kth_distances).mean()) / log(base)

def mi(x, y, z=None, k=3, base=2):
    """ Mutual information of x and y (conditioned on z if z is not None).

        x, y should be a list of vectors, e.g. x = [[1.3], [3.7], [5.1], [2.4]]
        if x is a one-dimensional scalar and we have four samples.
        Inputs are flattened to (n_samples, n_features); tiny noise is added to
        x and y (not to z). The result is expressed in units of log ``base``.
    """
    assert len(x) == len(y), "Arrays should have same length"
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"

    def _as_rows(values):
        # View the input as a 2-D (n_samples, n_features) array.
        arr = np.asarray(values)
        return arr.reshape(arr.shape[0], -1)

    x = add_noise(_as_rows(x))  # noise goes to x first, then y
    y = add_noise(_as_rows(y))

    conditional = z is not None
    if conditional:
        z = _as_rows(z)
        joint = np.hstack([x, y, z])
    else:
        joint = np.hstack([x, y])

    # k-th neighbour distances in the joint space (the trees use the max-norm).
    radii = query_neighbors(build_tree(joint), joint, k)

    if conditional:
        a = avgdigamma(np.c_[x, z], radii)
        b = avgdigamma(np.c_[y, z], radii)
        c = avgdigamma(z, radii)
        d = digamma(k)
    else:
        a = avgdigamma(x, radii)
        b = avgdigamma(y, radii)
        c = digamma(k)
        d = digamma(len(x))
    return (-a - b + c + d) / log(base)


# Utility functions

def add_noise(x, intens=1e-10):
    """Return ``x`` plus uniform noise in [0, intens), used to break degeneracy (ties)."""
    jitter = np.random.random_sample(x.shape)
    return x + intens * jitter


def query_neighbors(tree, x, k):
    """Return, for each row of ``x``, the distance in column ``k`` of a ``k + 1``-neighbour query on ``tree``."""
    # k + 1 neighbours are requested; presumably column 0 is the query point
    # itself when x was used to build the tree (verify against callers).
    distances = tree.query(x, k=k + 1)[0]
    return distances[:, k]


def count_neighbors(tree, x, r):
    """Return the per-row neighbour counts of ``x`` within radius ``r``, as reported by ``tree.query_radius``."""
    counts = tree.query_radius(x, r, count_only=True)
    return counts


def avgdigamma(points, dvec):
    """Return the mean of digamma(n_i), where n_i is the number of points
    within radius ``dvec[i]`` of point i in the given (marginal) space.
    """
    marginal_tree = build_tree(points)
    # Radii are shrunk by a tiny epsilon; presumably so that points lying
    # exactly at the k-th neighbour distance are not counted (verify).
    shrunk_radii = dvec - 1e-15
    neighbor_counts = count_neighbors(marginal_tree, points, shrunk_radii)
    return np.mean(digamma(neighbor_counts))


def build_tree(points):
    """Build a Chebyshev-metric neighbour tree: BallTree for 20 or more features, KDTree otherwise."""
    n_features = points.shape[1]
    tree_cls = BallTree if n_features >= 20 else KDTree
    return tree_cls(points, metric='chebyshev')

def get_unique_probs(x):
    """Return the empirical probability of each distinct row of the 2-D array ``x``.

    Returns:
        Tuple ``(probabilities, inverse)``: the relative frequency of every
        unique row, and the ``return_inverse`` output of ``np.unique`` mapping
        rows to their unique row.
    """
    # Reinterpret every row as one opaque fixed-width scalar so that np.unique
    # compares whole rows.
    row_dtype = np.dtype((np.void, x.dtype.itemsize * x.shape[1]))
    rows = np.ascontiguousarray(x).view(row_dtype)
    _, inverse, counts = np.unique(rows, return_index=False, return_inverse=True, return_counts=True)
    probabilities = np.asarray(counts / float(counts.sum()))
    return probabilities, inverse


In [7]:
import random
import time
import string
import pickle
import os
import numpy as np
import torch
from scipy.stats import kurtosis
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.cross_decomposition import CCA

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from typing import Dict
from scipy.stats import kurtosis

def generate_random_hash(length=16):
    """Generate a random identifier of ``length`` characters drawn from A-Z, a-z and 0-9."""
    alphabet = string.ascii_letters + string.digits  # all letters and digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)

def canonical_correlation(A, B):
    """Return the correlation between the first pair of canonical variates of A and B.

    A single-component CCA is fitted on (A, B); both inputs are projected onto
    that component and the Pearson correlation of the two projections is
    returned. A and B must have the same number of rows (samples).
    """
    assert A.shape[0] == B.shape[0]
    model = CCA(n_components=1)
    model.fit(A, B)
    scores_a, scores_b = model.transform(A, B)
    correlation_matrix = np.corrcoef(scores_a.T, scores_b.T)
    return correlation_matrix[0, 1]


def _snapshot_state_dict(model: nn.Module) -> Dict:
    """Return ``model.state_dict()`` with every tensor copied to the CPU.

    ``state_dict()`` hands out tensors that share storage with the live
    parameters, so storing it directly would make every checkpoint silently
    follow later optimizer steps. Copying freezes the weights as of now.
    """
    state = model.state_dict()
    for key, tensor in state.items():
        state[key] = tensor.detach().to("cpu", copy=True)
    return state


def _build_checkpoint(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    n_batches: int,
    epoch_loss: float,
    epoch_rec: float,
    epoch_kl: float,
) -> Dict:
    """Assemble one checkpoint entry: mean train losses, test metrics and a weight snapshot."""
    test_metrics = evaluate_model(model, test_loader, device)
    return {
        'train_loss': epoch_loss / n_batches,
        'train_rec': epoch_rec / n_batches,
        'train_kl': epoch_kl / n_batches,
        **test_metrics,
        'model_state': _snapshot_state_dict(model)
    }


def train_and_evaluate(
    config_id: int,
    output_dir: str,
    train_loader: DataLoader,
    test_loader: DataLoader,
    latent_dim: int,
    enc_filters: List[int],
    device: torch.device,
    experement_hash: str,
    is_variational: bool = True,
    kernel_size: int = 3,
    stride: int = 2,
    padding: int = 1,
    epochs: int = 10,
    input_channels: int = 3,
    input_size: int = 32,
    checkpoint_epochs: List[int] = [1, 5]
) -> Dict:
    """Train one VAE configuration and record metrics at selected epochs.

    Args:
        config_id: Index of the configuration (used in logs and the file name).
        output_dir: Directory the result pickle is written to (created if missing).
        train_loader: Loader yielding single-tensor training batches.
        test_loader: Loader passed to ``evaluate_model`` at every checkpoint.
        latent_dim: Latent dimensionality of the model.
        enc_filters: Channel counts of the encoder conv blocks.
        device: Device the model is trained on.
        experement_hash: Run identifier embedded in the result file name.
        is_variational: If True optimize reconstruction + KL; otherwise only
            the reconstruction term is back-propagated.
        kernel_size, stride, padding: Conv hyper-parameters forwarded to ``VAE``.
        epochs: Number of training epochs (must be >= 1).
        input_channels, input_size: Input geometry forwarded to ``VAE``.
        checkpoint_epochs: Epochs after which metrics and weights are recorded;
            the final epoch is always recorded as well.

    Returns:
        Dict with ``'config'`` (model statistics) and ``'checkpoints'`` (epoch
        -> metrics plus ``'model_state'``). The same dict is pickled to
        ``{output_dir}/config_{experement_hash}_{config_id}.pkl``.

    Raises:
        ValueError: If ``epochs < 1`` or a reconstruction contains NaN.
        TimeoutError: If the mean duration of the last three epochs exceeds
            30 minutes.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    print(f"\nConfig {config_id}: latent_dim={latent_dim}, filters={enc_filters}, kernel_size={kernel_size}, stride={stride}, padding={padding}")

    model = VAE(
        input_channels=input_channels,
        input_size=input_size,
        filter_sizes=enc_filters,
        latent_dim=latent_dim,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        is_variational=is_variational
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)

    results = {
        'config': calculate_model_stats(model),
        'checkpoints': {}
    }

    max_mean_epoch_seconds = 30 * 60  # abort configurations whose epochs are too slow
    checkpoint_set = set(checkpoint_epochs)
    n_batches = len(train_loader)
    epoch_durations = []

    for epoch in tqdm(range(1, epochs + 1)):
        start_time = time.time()
        model.train()  # evaluate_model() switches to eval mode at checkpoints
        epoch_loss = 0.0
        epoch_rec = 0.0
        epoch_kl = 0.0

        for x, in train_loader:
            x = x.to(device)

            optimizer.zero_grad()
            recon, mu, logvar = model(x)
            loss, rec_loss, kl_loss = vae_loss(recon, x, mu, logvar)

            if recon.isnan().any():
                print("Модель не обучилась на поданной конфигурации")
                raise ValueError("Модель не обучилась на поданной конфигурации")

            # The plain autoencoder ignores the KL term during optimization,
            # but the full loss is still logged below.
            if is_variational:
                loss.backward()
            else:
                rec_loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_rec += rec_loss.item()
            epoch_kl += kl_loss.item()

        epoch_durations.append(time.time() - start_time)
        if len(epoch_durations) >= 3 and np.mean(epoch_durations[-3:]) > max_mean_epoch_seconds:
            print("Рассчет эпохи слишком длительный, чтобы можно было использовать его в эксперименте!")
            raise TimeoutError("Рассчет эпохи слишком длительный, чтобы можно было использовать его в эксперименте!")

        # Record a checkpoint at the requested epochs and always at the last one.
        if epoch in checkpoint_set or epoch == epochs:
            results['checkpoints'][epoch] = _build_checkpoint(
                model, test_loader, device, n_batches, epoch_loss, epoch_rec, epoch_kl
            )

    # Save results
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/config_{experement_hash}_{config_id}.pkl", 'wb') as f:
        pickle.dump(results, f)

    return results


def _standardize_per_sample(batches, subsample_factor: int):
    """Stack batches, keep every ``subsample_factor``-th sample and flatten each sample.

    The StandardScaler is fitted on the transposed matrix, i.e. its statistics
    are computed per sample (row) rather than per feature.
    """
    stacked = np.vstack(batches)[::subsample_factor]
    flat = stacked.reshape(stacked.shape[0], -1)
    return StandardScaler().fit_transform(flat.T).T


def _safe_metric(name: str, estimator, a, b):
    """Return ``float(estimator(a, b))``, or None (after logging) if the estimator fails."""
    try:
        return float(estimator(a, b))
    except Exception as exc:  # best effort: one failing metric must not abort the evaluation
        print(f"{name}:", a[0], b[0])
        print(f"{name} failed: {exc!r}")
        return None


def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    subsample_factor: int = 15,
) -> Dict:
    """Evaluate model on test set with multiple metrics.

    Computes, in this order:
      * mean total / reconstruction / KL loss over the test batches;
      * mutual information and canonical correlation between inputs,
        reconstructions and latent means, on every ``subsample_factor``-th
        sample (a metric that fails is stored as None);
      * kurtosis and absolute-value quantiles of the weights, grouped by
        parameter name into encoder, fully connected ("flow") and decoder.

    Args:
        model: Model whose forward pass returns ``(recon, mu, logvar)``; it is
            switched to eval mode.
        test_loader: Loader yielding single-tensor batches.
        device: Device the batches are moved to before the forward pass.
        subsample_factor: Keep every n-th sample for the MI / CCA metrics.

    Returns:
        Dict of metrics; every key is prefixed with ``"test_"``.
    """
    model.eval()
    total_loss, rec_loss, kl_loss = 0., 0., 0.
    x_batches, recon_batches, mu_batches = [], [], []

    with torch.no_grad():
        for x, in test_loader:
            x = x.to(device)  # same handling as in the training loop
            recon, mu, logvar = model(x)

            loss, rl, kl = vae_loss(recon, x, mu, logvar)
            total_loss += loss.item()
            rec_loss += rl.item()
            kl_loss += kl.item()

            # Keep CPU copies for the MI / CCA computations below
            x_batches.append(x.cpu().numpy())
            recon_batches.append(recon.cpu().numpy())
            mu_batches.append(mu.cpu().numpy())

    num_batches = len(test_loader)
    metrics = {
        'total': total_loss / num_batches,
        'rec': rec_loss / num_batches,
        'kl': kl_loss / num_batches,
    }

    # Concatenate all batches and subsample for speed
    all_x = _standardize_per_sample(x_batches, subsample_factor)
    all_recon = _standardize_per_sample(recon_batches, subsample_factor)
    all_mu = _standardize_per_sample(mu_batches, subsample_factor)

    # Dependence metrics between input, latent mean and reconstruction
    pairs = [
        ('x_recon', all_x, all_recon),
        ('x_mu', all_x, all_mu),
        ('mu_recon', all_mu, all_recon),
    ]
    for prefix, estimator in (('mi', mi), ('cca', canonical_correlation)):
        for suffix, first, second in pairs:
            key = f'{prefix}_{suffix}'
            metrics[key] = _safe_metric(key, estimator, first, second)

    # Group the flattened weight tensors by the part of the model they belong
    # to, matching on parameter names.
    encoder_weights = []
    flow_weights = []
    decoder_weights = []
    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue
        flat = param.detach().cpu().numpy().flatten()
        if 'encoder' in name:
            encoder_weights.append(flat)
        if 'fc_mu' in name or 'fc_logvar' in name or 'fc_decode' in name:
            flow_weights.append(flat)
        if 'decoder' in name:
            decoder_weights.append(flat)

    weight_groups = {
        'encoder': np.concatenate(encoder_weights),
        'flow': np.concatenate(flow_weights),
        'decoder': np.concatenate(decoder_weights),
    }

    # Kurtosis of the weights
    for group, weights in weight_groups.items():
        metrics[f'kurtosis_{group}_weights'] = kurtosis(weights)

    # Quantiles of the absolute weights [10%, 30%, 50%, 70%, 90%]
    quantiles = [0.1, 0.3, 0.5, 0.7, 0.9]
    for group, weights in weight_groups.items():
        values = np.quantile(np.abs(weights), quantiles).tolist()
        for q, vq in zip(quantiles, values):
            metrics[f'quantile_{q}_{group}_weights'] = vq

    return {"test_" + k: v for k, v in metrics.items()}

def calculate_max_num_conv_layers(input_size: int, kernel_size: int, stride: int, padding: int) -> int:
    """
    Compute how many conv layers can be stacked while the feature map stays larger than 1 x 1.

    The layer that would shrink the map to 1 x 1 (or below) is not counted.
    If a convolution does not shrink the map at all (same or larger output),
    the search stops immediately and 0 is returned instead of looping forever.

    Args:
        input_size (int): Input size (N for an N x N image).
        kernel_size (int): Convolution kernel size (K).
        stride (int): Convolution stride (S).
        padding (int): Padding (P).

    Returns:
        int: Number of usable conv layers (never negative).
    """
    current_size = input_size
    num_layers = 0
    while current_size > 1:
        # Conv arithmetic: floor((N + 2P - K) / S) + 1
        next_size = (current_size + 2 * padding - kernel_size) // stride + 1
        num_layers += 1
        if next_size >= current_size:
            # Not shrinking: the size would never reach 1, so stop here.
            break
        current_size = next_size
    # The last counted layer is the one that hit the stop condition; drop it.
    return max(num_layers - 1, 0)

## CIFAR

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_cifar_dataloader(device=device, train_batch_size=1000, test_batch_size=5000)
# Random tag embedded in the name of every result file written by this sweep.
experement_hash = generate_random_hash(length=16)
print(experement_hash)

# Random search: every iteration samples one architecture, trains it as a plain
# (non-variational) autoencoder and stores its metrics in ./results_cifar_AE.
for i in range(0, 1000):
    kernel_size = np.random.randint(3, 7)
    stride = np.random.randint(1, kernel_size)
    # With stride 1 the largest padding value (kernel_size // 2) is excluded.
    if stride == 1:
        padding = np.random.randint(0, kernel_size // 2)
    else:
        padding = np.random.randint(0, kernel_size // 2 + 1)

    # Depth is capped at 5 layers and at what the 32x32 input can support.
    max_num_layers = min(5, calculate_max_num_conv_layers(input_size = 32, kernel_size = kernel_size, stride = stride, padding = padding))
    num_layers = np.random.randint(1, max_num_layers + 1)

    latent_dim = 8 * np.random.randint(2, 65)
    # Distinct channel counts on a geometric grid (1.25 ** exponent), increasing with depth.
    enc_filters = random.sample(list(range(5, 25)), num_layers)
    enc_filters = sorted([round(1.25 ** x) for x in enc_filters])

    max_epoch = 100

    # Train and evaluate single configuration
    try:
        results = train_and_evaluate(
            config_id=i,
            output_dir="./results_cifar_AE",
            train_loader = train_loader,
            test_loader = test_loader,
            latent_dim = latent_dim,
            enc_filters = enc_filters,
            device = device,
            experement_hash = experement_hash,
            is_variational=False,
            kernel_size = kernel_size,
            stride = stride,
            padding = padding,
            epochs = max_epoch,
            input_channels = 3,
            input_size = 32,
            checkpoint_epochs = [1, 5]  # record metrics after epochs 1 and 5
        )
    except (TimeoutError, ValueError):
        # Raised by train_and_evaluate for too-slow epochs / NaN reconstructions:
        # skip this configuration.
        continue
    except torch.OutOfMemoryError:
        print("Получившаяся конфигурация переполнила память GPU!")
        torch.cuda.empty_cache()
        continue

## Fashion-MNIST

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_fashion_mnist_dataloader(
    device=device, train_batch_size=1000, test_batch_size=5000
)
# Random tag embedded in the name of every result file written by this sweep.
experement_hash = generate_random_hash(length=16)
print(experement_hash)

# Random search over non-variational autoencoder architectures on 28x28 inputs.
for i in range(1000):
    kernel_size = np.random.randint(3, 6)
    stride = np.random.randint(1, kernel_size)
    # With stride 1 the largest padding value (kernel_size // 2) is excluded.
    padding = np.random.randint(0, kernel_size // 2 if stride == 1 else kernel_size // 2 + 1)

    # Depth is capped at 5 layers and at what the input size can support.
    max_num_layers = min(
        5,
        calculate_max_num_conv_layers(input_size=28, kernel_size=kernel_size, stride=stride, padding=padding),
    )
    num_layers = np.random.randint(1, max_num_layers + 1)

    latent_dim = 8 * np.random.randint(2, 65)
    # Distinct channel counts on a geometric grid (1.25 ** exponent), increasing with depth.
    enc_filters = sorted(round(1.25 ** x) for x in random.sample(list(range(5, 25)), num_layers))

    max_epoch = 100

    # Train and evaluate one configuration; failed configurations are skipped.
    try:
        results = train_and_evaluate(
            config_id=i,
            output_dir="./results_fmnist_AE",
            train_loader=train_loader,
            test_loader=test_loader,
            latent_dim=latent_dim,
            enc_filters=enc_filters,
            device=device,
            experement_hash=experement_hash,
            is_variational=False,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            epochs=max_epoch,
            input_channels=1,
            input_size=28,
            checkpoint_epochs=[1, 5],  # record metrics after epochs 1 and 5
        )
    except (TimeoutError, ValueError):
        continue
    except torch.OutOfMemoryError:
        print("Получившаяся конфигурация переполнила память GPU!")
        torch.cuda.empty_cache()
        continue