In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import os
import sys

# --- CONFIGURATION ---
# NOTE(review): FAST_MODE is not read anywhere in the visible part of this file.
FAST_MODE = True
SUBSET_SIZE = 14500  # number of images sampled for training
TEST_SIZE = 1500    # number of images sampled for testing
# Mini-batch size used by the train/test DataLoaders.
BATCH_SIZE = 32
EPOCHS = 8         # default number of training epochs (kept small for speed)
LATENT_DIM = 16    # size of the (deliberately small) latent space
NOISE_FACTOR = 0.3  # scale of the Gaussian noise used for denoising

# --- Cell 1: AUTOENCODER ARCHITECTURE ---
class FastAutoencoder(nn.Module):
    """Small convolutional autoencoder for square RGB images.

    The encoder halves the spatial size twice with stride-2 convolutions and
    projects the flattened feature map to ``latent_dim`` values.  The decoder
    mirrors it with transposed convolutions and ends in a sigmoid, so outputs
    lie in [0, 1].

    Args:
        latent_dim: Number of values in the bottleneck representation.
        img_size: Side length of the square input images.  Must be a positive
            multiple of 4 so that two stride-2 stages divide it evenly and the
            decoder restores exactly the input size.  Defaults to 64, which
            reproduces the original fixed architecture.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 4.
    """

    def __init__(self, latent_dim=16, img_size=64):
        super().__init__()

        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 4, got {img_size!r}"
            )

        # Spatial side of the feature map after the two stride-2 convolutions.
        side = int(img_size) // 4
        flat_features = 16 * side * side

        # Encoder: compress the image into the latent vector.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),   # S -> S/2
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),  # S/2 -> S/4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, latent_dim),  # project to the latent space
        )

        # Decoder: rebuild the image from the latent vector.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, (16, side, side)),
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # squash pixel values into [0, 1]
        )

    def forward(self, x):
        """Encode ``x`` to the latent space and return its reconstruction."""
        z = self.encoder(x)
        return self.decoder(z)

# --- Cell 2: DATA LOADING ---
def load_tiny_dataset(data_path, img_size=64):
    """Load a random train/test sample of an image-folder dataset.

    Images under ``data_path`` are read with ``ImageFolder``, resized to
    ``img_size`` x ``img_size`` and converted to tensors.  Up to
    ``SUBSET_SIZE`` training and ``TEST_SIZE`` test indices are taken, without
    overlap, from a single random permutation.  When the dataset holds fewer
    than ``SUBSET_SIZE + TEST_SIZE`` images, the available images are split in
    the same train/test proportion so the test split is not left empty.

    If anything goes wrong while loading, the error is printed and synthetic
    ``FakeData`` with three classes is returned instead (best-effort demo
    mode).

    Args:
        data_path: Root directory in the layout expected by ``ImageFolder``.
        img_size: Side length, in pixels, of the square images produced.

    Returns:
        A ``(train_dataset, test_dataset, class_names)`` tuple.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])

    try:
        # Fail early (and fall back to demo data) when the path is missing.
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"–®–ª—è—Ö {data_path} –Ω–µ —ñ—Å–Ω—É—î")

        full_dataset = ImageFolder(root=data_path, transform=transform)
        print(f"–£—Å–ø—ñ—à–Ω–æ –∑–∞–≤–∞–Ω—Ç–∞–∂–µ–Ω–æ: {len(full_dataset)} –∑–æ–±—Ä–∞–∂–µ–Ω—å")
        print(f"   –ö–ª–∞—Å–∏: {full_dataset.classes}")

        # Draw only a limited random sample to keep the experiment fast.
        requested = SUBSET_SIZE + TEST_SIZE
        total_samples = min(requested, len(full_dataset))
        # Plain Python ints keep Subset indexing independent of tensor types.
        indices = torch.randperm(len(full_dataset))[:total_samples].tolist()

        if total_samples < requested:
            # Not enough images for both splits: keep the configured ratio
            # instead of giving everything to the training split.
            train_size = total_samples * SUBSET_SIZE // requested
        else:
            train_size = SUBSET_SIZE
        train_indices = indices[:train_size]
        test_indices = indices[train_size:]

        train_dataset = Subset(full_dataset, train_indices)
        test_dataset = Subset(full_dataset, test_indices)

        print(f" –í–∏–∫–æ—Ä–∏—Å—Ç–æ–≤—É—î–º–æ: {len(train_dataset)} —Ç—Ä–µ–Ω—É–≤–∞–ª—å–Ω–∏—Ö, {len(test_dataset)} —Ç–µ—Å—Ç–æ–≤–∏—Ö")

        return train_dataset, test_dataset, full_dataset.classes

    except Exception as e:
        # Deliberate best-effort: any loading failure switches to demo data.
        print(f" –ü–æ–º–∏–ª–∫–∞ –∑–∞–≤–∞–Ω—Ç–∞–∂–µ–Ω–Ω—è: {e}")
        print("–°—Ç–≤–æ—Ä—é—î–º–æ –¥–µ–º–æ-–¥–∞–Ω—ñ.")
        from torchvision.datasets import FakeData
        fake_shape = (3, img_size, img_size)
        train_dataset = FakeData(size=SUBSET_SIZE, image_size=fake_shape, num_classes=3,
                                 transform=transforms.ToTensor())
        # FakeData seeds each sample with index + random_offset; shifting the
        # test set past the training range keeps it from repeating train images.
        test_dataset = FakeData(size=TEST_SIZE, image_size=fake_shape, num_classes=3,
                                transform=transforms.ToTensor(), random_offset=SUBSET_SIZE)
        class_names = ['cat', 'dog', 'wild']

        print(f"üìä –î–µ–º–æ-–¥–∞–Ω—ñ: {len(train_dataset)} —Ç—Ä–µ–Ω—É–≤–∞–ª—å–Ω–∏—Ö, {len(test_dataset)} —Ç–µ—Å—Ç–æ–≤–∏—Ö")

        return train_dataset, test_dataset, class_names

# Helper that corrupts images with noise (used for denoising)
def add_noise(images, noise_factor=0.3):
    """Return ``images`` plus scaled Gaussian noise, clipped to [0, 1].

    Args:
        images: Tensor of pixel values.
        noise_factor: Multiplier applied to standard-normal noise of the same
            shape as ``images``.
    """
    perturbed = images + torch.randn_like(images) * noise_factor
    return perturbed.clamp(min=0.0, max=1.0)

# Load the data.
# DATA_PATH is a machine-specific absolute path; load_tiny_dataset prints the
# error and returns synthetic demo data when the path does not exist.
DATA_PATH = "D:/Kotopes/Kotopes/data"
print(f" –ó–∞–≤–∞–Ω—Ç–∞–∂–µ–Ω–Ω—è –¥–∞–Ω–∏—Ö –∑: {DATA_PATH}")
train_dataset, test_dataset, class_names = load_tiny_dataset(DATA_PATH)

# --- Cell 3: AUTOENCODER TRAINING ---
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f" –í–∏–∫–æ—Ä–∏—Å—Ç–æ–≤—É—î—Ç—å—Å—è –ø—Ä–∏—Å—Ç—Ä—ñ–π: {device}")

# Model, reconstruction loss and optimizer; these module-level objects are
# read directly by fast_train and test_denoising.
model = FastAutoencoder(latent_dim=LATENT_DIM).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Only the training loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

def fast_train(model, train_loader, test_loader, epochs=EPOCHS, denoise_mode=False):
    """Train the autoencoder in plain or denoising mode.

    Relies on the module-level ``optimizer``, ``criterion`` and ``device``.
    In denoising mode the network receives a noisy copy of each batch while
    the loss is still measured against the clean batch.

    Args:
        model: Network to train.
        train_loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        test_loader: Accepted for interface compatibility; not used here.
        epochs: Number of passes over ``train_loader``.
        denoise_mode: When true, corrupt the inputs with ``add_noise``.

    Returns:
        List with the mean training loss of every epoch.
    """
    mode = "–î–ï–ù–û–á–ó–ò–ù–ì" if denoise_mode else "–ê–í–¢–û–ï–ù–ö–û–î–ï–†"
    history = []

    for epoch_idx in range(epochs):
        model.train()
        running = 0
        for clean, _ in train_loader:
            clean = clean.to(device)
            # The reconstruction target is always the clean batch; only the
            # network input differs between the two modes.
            inputs = add_noise(clean, NOISE_FACTOR) if denoise_mode else clean

            optimizer.zero_grad()
            loss = criterion(model(inputs), clean)
            loss.backward()
            optimizer.step()
            running += loss.item()

        epoch_loss = running / len(train_loader)
        history.append(epoch_loss)
        print(f'Epoch {epoch_idx+1}/{epochs} [{mode}]: Loss: {epoch_loss:.4f}')

    return history

# Train as a plain (non-denoising) autoencoder; train_losses receives the
# list of per-epoch mean losses returned by fast_train.
print("–ü–æ—á–∞—Ç–æ–∫ —Ç—Ä–µ–Ω—É–≤–∞–Ω–Ω—è –∞–≤—Ç–æ–µ–Ω–∫–æ–¥–µ—Ä–∞...")
train_losses = fast_train(model, train_loader, test_loader, denoise_mode=False)

# --- Cell 4: ANOMALY DETECTION ---
print("\n===  –î–ï–¢–ï–ö–¶–Ü–Ø –ê–ù–û–ú–ê–õ–Ü–ô ===")

# Pick the class treated as "normal": 'dog' when present, otherwise class 0.
NORMAL_CLASS = class_names.index('dog') if 'dog' in class_names else 0
print(f"–ù–æ—Ä–º–∞–ª—å–Ω–∏–π –∫–ª–∞—Å: {class_names[NORMAL_CLASS]}")

# Split test samples into normal and anomalous, keeping at most
# MAX_PER_GROUP of each for speed.  Reading a label means loading the image,
# so stop scanning as soon as both groups are full instead of walking the
# whole test set and truncating afterwards; the resulting lists are the same
# first-N indices per group.
MAX_PER_GROUP = 20
test_normal_indices = []
test_anomalous_indices = []

for i in range(len(test_dataset)):
    if (len(test_normal_indices) >= MAX_PER_GROUP
            and len(test_anomalous_indices) >= MAX_PER_GROUP):
        break
    _, label = test_dataset[i]
    if label == NORMAL_CLASS:
        group = test_normal_indices
    else:
        group = test_anomalous_indices
    if len(group) < MAX_PER_GROUP:
        group.append(i)

test_normal = Subset(test_dataset, test_normal_indices)
test_anomalous = Subset(test_dataset, test_anomalous_indices)

print(f"–¢–µ—Å—Ç–æ–≤—ñ –¥–∞–Ω—ñ: {len(test_normal)} –Ω–æ—Ä–º–∞–ª—å–Ω–∏—Ö, {len(test_anomalous)} –∞–Ω–æ–º–∞–ª—å–Ω–∏—Ö")

# Helper that computes reconstruction errors
def fast_errors(model, dataset, device):
    """Return the per-sample reconstruction MSE of ``model`` over ``dataset``.

    The model is switched to eval mode and run without gradients in batches
    of 32.  Each sample's error is the mean squared difference between the
    image and its reconstruction, averaged over dimensions 1-3.

    Returns:
        1-D NumPy array with one error per sample (empty for an empty dataset).
    """
    model.eval()
    per_batch = []
    with torch.no_grad():
        for batch, _ in DataLoader(dataset, batch_size=32):
            batch = batch.to(device)
            residual = batch - model(batch)
            sample_mse = residual.pow(2).mean(dim=(1, 2, 3))
            per_batch.append(sample_mse.cpu().numpy())
    if not per_batch:
        return np.array([])
    return np.concatenate(per_batch)

# Compute reconstruction errors for the normal and anomalous test samples.
normal_errors = fast_errors(model, test_normal, device)
anomalous_errors = fast_errors(model, test_anomalous, device)

# Anomaly threshold: 95th percentile of the errors on normal samples.
# NOTE(review): the threshold and the false-positive rate below are both
# computed from the same normal_errors array — confirm this is intended.
threshold = np.percentile(normal_errors, 95)
# Percentage of anomalous / normal samples whose error exceeds the threshold.
detection_rate = np.mean(anomalous_errors > threshold) * 100
false_positive = np.mean(normal_errors > threshold) * 100

print(f" –†–ï–ó–£–õ–¨–¢–ê–¢–ò –î–ï–¢–ï–ö–¶–Ü–á –ê–ù–û–ú–ê–õ–Ü–ô:")
print(f"   - –ü–æ—Ä—ñ–≥: {threshold:.4f}")
print(f"   - –í–∏—è–≤–ª–µ–Ω–æ –∞–Ω–æ–º–∞–ª—ñ–π: {detection_rate:.1f}%")
print(f"   - –ü–æ–º–∏–ª–∫–∞ (FPR): {false_positive:.1f}%")

# --- Cell 5: DENOISING ---
print("\n=== üßπ –¢–ï–°–¢–£–í–ê–ù–ù–Ø –î–ï–ù–û–á–ó–ò–ù–ì–£ ===")

def test_denoising(model, test_samples=5):
    """Measure how well ``model`` removes synthetic noise from test images.

    Takes up to ``test_samples`` normal and ``test_samples`` anomalous test
    images (from the module-level index lists), corrupts them with
    ``add_noise``, runs them through the model and prints the MSE of the
    noisy and denoised images against the clean ones.

    Args:
        model: Autoencoder to evaluate; it is put into eval mode.
        test_samples: Maximum number of images taken from each group.

    Returns:
        ``(clean_imgs, noisy_imgs, denoised_imgs, labels)`` for the evaluated
        batch, or ``(None, None, None, None)`` when no test images are
        available, so callers can check ``clean_imgs is not None``.
    """
    model.eval()

    # Take a few test images from each group.
    denoise_indices = test_normal_indices[:test_samples] + test_anomalous_indices[:test_samples]
    if not denoise_indices:
        # Nothing to evaluate: return a 4-tuple so the caller's unpacking and
        # its `is not None` checks keep working.
        return None, None, None, None

    denoise_dataset = Subset(test_dataset, denoise_indices)
    # A single batch holds every selected image.
    denoise_loader = DataLoader(denoise_dataset, batch_size=len(denoise_indices))

    with torch.no_grad():
        clean_imgs, labels = next(iter(denoise_loader))
        clean_imgs = clean_imgs.to(device)

        # Corrupt the inputs, then reconstruct from the noisy version.
        noisy_imgs = add_noise(clean_imgs, NOISE_FACTOR)
        denoised_imgs = model(noisy_imgs)

        # Denoising quality, all measured against the clean images.
        mse_clean = criterion(clean_imgs, clean_imgs).item()  # zero baseline
        mse_noisy = criterion(noisy_imgs, clean_imgs).item()
        mse_denoised = criterion(denoised_imgs, clean_imgs).item()

    # With no effective noise (mse_noisy == 0) a relative improvement is
    # undefined; report 0 instead of dividing by zero.
    if mse_noisy > 0:
        improvement = ((mse_noisy - mse_denoised) / mse_noisy) * 100
    else:
        improvement = 0.0

    print(f" –Ø–∫—ñ—Å—Ç—å –¥–µ–Ω–æ—ó–∑–∏–Ω–≥—É:")
    print(f"   - MSE (—á–∏—Å—Ç—ñ): {mse_clean:.4f}")
    print(f"   - MSE (—à—É–º–Ω—ñ): {mse_noisy:.4f}")
    print(f"   - MSE (–≤—ñ–¥–Ω–æ–≤–ª–µ–Ω—ñ): {mse_denoised:.4f}")
    print(f"   - –ü–æ–∫—Ä–∞—â–µ–Ω–Ω—è: {improvement:.1f}%")

    return clean_imgs, noisy_imgs, denoised_imgs, labels

# Run the denoising test; `labels` holds the labels of the returned images.
clean_imgs, noisy_imgs, denoised_imgs, labels = test_denoising(model)

# --- Cell 6: RESULTS VISUALIZATION ---
plt.figure(figsize=(15, 4))

# Panel 1: training loss curve
plt.subplot(1, 4, 1)
plt.plot(train_losses)
plt.title('–í—Ç—Ä–∞—Ç–∏ –ø—Ä–∏ —Ç—Ä–µ–Ω—É–≤–∞–Ω–Ω—ñ')
plt.xlabel('–ï–ø–æ—Ö–∞')
plt.ylabel('MSE Loss')
plt.grid(True, alpha=0.3)

# Panel 2: reconstruction-error distributions with the anomaly threshold
plt.subplot(1, 4, 2)
plt.hist(normal_errors, alpha=0.7, label='–ù–æ—Ä–º–∞–ª—å–Ω—ñ', bins=15, color='green')
plt.hist(anomalous_errors, alpha=0.7, label='–ê–Ω–æ–º–∞–ª—å–Ω—ñ', bins=15, color='red')
plt.axvline(threshold, color='black', linestyle='--', label=f'–ü–æ—Ä—ñ–≥: {threshold:.3f}')
plt.legend()
plt.title('–†–æ–∑–ø–æ–¥—ñ–ª –ø–æ–º–∏–ª–æ–∫ —Ä–µ–∫–æ–Ω—Å—Ç—Ä—É–∫—Ü—ñ—ó')
plt.xlabel('–ü–æ–º–∏–ª–∫–∞ MSE')
plt.grid(True, alpha=0.3)

# Panel 3: reconstruction example
# NOTE(review): `reconstructed` is computed below but never plotted in the
# visible code; only the original image is shown in this panel.
plt.subplot(1, 4, 3)
sample_idx = 0
sample_img, sample_label = test_dataset[sample_idx]
sample_img = sample_img.unsqueeze(0).to(device)
with torch.no_grad():
    reconstructed = model(sample_img)

# Move channels last for imshow.
plt.imshow(sample_img[0].cpu().permute(1, 2, 0))
plt.title(f'–û—Ä–∏–≥—ñ–Ω–∞–ª: {class_names[sample_label]}')
plt.axis('off')

# Panel 4: denoising example
plt.subplot(1, 4, 4)
if clean_imgs is not None:
    # Show the first denoised image
    plt.imshow(denoised_imgs[0].cpu().permute(1, 2, 0))
    plt.title(f'–í—ñ–¥–Ω–æ–≤–ª–µ–Ω–µ (–¥–µ–Ω–æ—ó–∑–∏–Ω–≥)')
    plt.axis('off')

plt.tight_layout()
plt.show()

# --- Cell 7: LATENT SPACE VISUALIZATION ---
print("\n===  –í–Ü–ó–£–ê–õ–Ü–ó–ê–¶–Ü–Ø –õ–ê–¢–ï–ù–¢–ù–û–ì–û –ü–†–û–°–¢–û–†–£ ===")

def fast_latent_vectors(model, dataset, device, max_samples=50):
    """Encode up to ``max_samples`` items of ``dataset`` with ``model.encoder``.

    The dataset is read in batches of 32 with the model in eval mode and
    gradients disabled; reading stops once at least ``max_samples`` codes
    have been gathered.

    Returns:
        ``(latent_vectors, labels)`` as NumPy arrays, each truncated to the
        first ``max_samples`` entries.
    """
    model.eval()
    code_chunks = []
    label_chunks = []
    gathered = 0

    with torch.no_grad():
        for batch, targets in DataLoader(dataset, batch_size=32):
            codes = model.encoder(batch.to(device))  # compress to latent space
            code_chunks.append(codes.cpu().numpy())
            label_chunks.append(targets.numpy())
            gathered += len(code_chunks[-1])
            if gathered >= max_samples:
                break

    if not code_chunks:
        # No batches at all: same empty arrays as building from empty lists.
        return np.array([])[:max_samples], np.array([])[:max_samples]

    all_codes = np.concatenate(code_chunks)
    all_labels = np.concatenate(label_chunks)
    return all_codes[:max_samples], all_labels[:max_samples]

# Extract latent representations of the test images.
# The labels are stored under their own name so that the global `labels`
# returned by test_denoising (used later for the denoising figure titles)
# is not overwritten by this cell.
latent_vectors, latent_labels = fast_latent_vectors(model, test_dataset, device)

print(f"üìä –í–∏—Ç—è–≥–Ω—É—Ç–æ {len(latent_vectors)} –ª–∞—Ç–µ–Ω—Ç–Ω–∏—Ö –≤–µ–∫—Ç–æ—Ä—ñ–≤ —Ä–æ–∑–º—ñ—Ä–Ω—ñ—Å—Ç—é {LATENT_DIM}")

# t-SNE projection of the latent space to 2-D for plotting
if len(latent_vectors) > 10:
    print("üîç –ó–∞—Å—Ç–æ—Å–æ–≤—É—î–º–æ t-SNE –¥–ª—è –≤—ñ–∑—É–∞–ª—ñ–∑–∞—Ü—ñ—ó...")
    # Perplexity is capped below the sample count.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(10, len(latent_vectors)-1))
    latent_2d = tsne.fit_transform(latent_vectors)

    plt.figure(figsize=(12, 5))

    # Left panel: points grouped by class
    plt.subplot(1, 2, 1)
    # Kept for compatibility; the scatter calls below use matplotlib's
    # default color cycle rather than this list.
    colors = ['red', 'blue', 'green', 'orange', 'purple']
    for class_idx, class_name in enumerate(class_names):
        mask = latent_labels == class_idx
        if mask.any():
            plt.scatter(latent_2d[mask, 0], latent_2d[mask, 1],
                        label=class_name, alpha=0.7, s=30)
    plt.legend()
    plt.title('–õ–∞—Ç–µ–Ω—Ç–Ω–∏–π –ø—Ä–æ—Å—Ç—ñ—Ä –∑–∞ –∫–ª–∞—Å–∞–º–∏')
    plt.xlabel('t-SNE –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–∞ 1')
    plt.ylabel('t-SNE –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–∞ 2')
    plt.grid(True, alpha=0.3)

    # Right panel: normal class versus everything else
    plt.subplot(1, 2, 2)
    is_normal = (latent_labels == NORMAL_CLASS).astype(int)
    plt.scatter(latent_2d[is_normal==1, 0], latent_2d[is_normal==1, 1],
                c='green', label=f'–ù–æ—Ä–º–∞–ª—å–Ω—ñ ({class_names[NORMAL_CLASS]})', alpha=0.7, s=30)
    plt.scatter(latent_2d[is_normal==0, 0], latent_2d[is_normal==0, 1],
                c='red', label='–ê–Ω–æ–º–∞–ª—å–Ω—ñ', alpha=0.7, s=30)
    plt.legend()
    plt.title('–ü–æ–¥—ñ–ª –Ω–æ—Ä–º–∞–ª—å–Ω–∏—Ö/–∞–Ω–æ–º–∞–ª—å–Ω–∏—Ö —É –ª–∞—Ç–µ–Ω—Ç–Ω–æ–º—É –ø—Ä–æ—Å—Ç–æ—Ä—ñ')
    plt.xlabel('t-SNE –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–∞ 1')
    plt.ylabel('t-SNE –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–∞ 2')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print(" –ù–µ–¥–æ—Å—Ç–∞—Ç–Ω—å–æ –¥–∞–Ω–∏—Ö –¥–ª—è t-SNE")

# --- Cell 8: ADDITIONAL DENOISING VISUALIZATION ---
print("\n=== üñºÔ∏è –í–Ü–ó–£–ê–õ–Ü–ó–ê–¶–Ü–Ø –†–ï–ó–£–õ–¨–¢–ê–¢–Ü–í –î–ï–ù–û–á–ó–ò–ù–ì–£ ===")

if clean_imgs is not None:
    plt.figure(figsize=(12, 4))

    # Show a few denoising examples (at most three columns)
    num_examples = min(3, len(clean_imgs))

    # Grid layout: row 1 = originals, row 2 = noisy, row 3 = denoised.
    for i in range(num_examples):
        # Original image
        # `labels` must be the per-image labels returned together with
        # clean_imgs by test_denoising; make sure no other cell has rebound
        # the name before this one runs.
        plt.subplot(3, num_examples, i + 1)
        plt.imshow(clean_imgs[i].cpu().permute(1, 2, 0))
        plt.title(f'–û—Ä–∏–≥—ñ–Ω–∞–ª\n{class_names[labels[i]]}')
        plt.axis('off')

        # Noisy image
        plt.subplot(3, num_examples, i + 1 + num_examples)
        plt.imshow(noisy_imgs[i].cpu().permute(1, 2, 0))
        plt.title('–ó–∞—à—É–º–ª–µ–Ω–µ')
        plt.axis('off')

        # Denoised image
        plt.subplot(3, num_examples, i + 1 + 2*num_examples)
        plt.imshow(denoised_imgs[i].cpu().permute(1, 2, 0))
        plt.title('–í—ñ–¥–Ω–æ–≤–ª–µ–Ω–µ')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Final summary of the run.
# The counts printed are the configured SUBSET_SIZE / TEST_SIZE values, not
# the lengths of the datasets actually loaded.
print("\n –ï–ö–°–ü–ï–†–ò–ú–ï–ù–¢ –ó–ê–í–ï–†–®–ï–ù–û!")
print(f" –í–∏–∫–æ—Ä–∏—Å—Ç–∞–Ω–æ: {SUBSET_SIZE} —Ç—Ä–µ–Ω—É–≤–∞–ª—å–Ω–∏—Ö —Ç–∞ {TEST_SIZE} —Ç–µ—Å—Ç–æ–≤–∏—Ö –∑–æ–±—Ä–∞–∂–µ–Ω—å")
print(f" –ö—ñ–ª—å–∫—ñ—Å—Ç—å –µ–ø–æ—Ö: {EPOCHS}")
print(f" –ö–ª–∞—Å–∏: {class_names}")
print(f"–†–ï–ê–õ–Ü–ó–û–í–ê–ù–Ü –§–£–ù–ö–¶–Ü–á:")
print(f"   1. –î–µ—Ç–µ–∫—Ü—ñ—è –∞–Ω–æ–º–∞–ª—ñ–π (–ø–æ—Ä—ñ–≥: {threshold:.4f})")
print(f"   2. –ó–º–µ–Ω—à–µ–Ω–Ω—è —Ä–æ–∑–º—ñ—Ä–Ω–æ—Å—Ç—ñ —Ç–∞ –≤—ñ–∑—É–∞–ª—ñ–∑–∞—Ü—ñ—è ({LATENT_DIM}D -> 2D)")
print(f"   3. –î–µ–Ω–æ—ó–∑–∏–Ω–≥ –∑–æ–±—Ä–∞–∂–µ–Ω—å (—à—É–º: {NOISE_FACTOR})")

: 