In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import os
import sys

# Experiment configuration.
FAST_MODE = True     # NOTE(review): not read anywhere in the code shown here
SUBSET_SIZE = 14500  # number of images sampled for training
TEST_SIZE = 1500     # number of images sampled for testing
BATCH_SIZE = 32
EPOCHS = 8           # kept small for a quick run
LATENT_DIM = 16      # size of the autoencoder's bottleneck vector

# --- Cell 1: ARCHITECTURE ---
class FastAutoencoder(nn.Module):
    """Small convolutional autoencoder for 3-channel images.

    The encoder applies two stride-2 convolutions (64x64 -> 32x32 -> 16x16 per
    the shape comments) and projects the flattened 16x16x16 feature map onto a
    ``latent_dim`` vector. The decoder mirrors that path back to 3 channels and
    ends in a sigmoid, so every output value lies in [0, 1].

    Args:
        latent_dim: length of the bottleneck vector produced by ``encoder``.
    """

    def __init__(self, latent_dim=16):
        super().__init__()
        hidden_ch, feat_ch, feat_side = 8, 16, 16
        flat_size = feat_ch * feat_side * feat_side

        self.encoder = nn.Sequential(
            nn.Conv2d(3, hidden_ch, kernel_size=3, stride=2, padding=1),        # 64 -> 32
            nn.ReLU(),
            nn.Conv2d(hidden_ch, feat_ch, kernel_size=3, stride=2, padding=1),  # 32 -> 16
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_size, latent_dim),
        )

        # output_padding=1 makes each stride-2 transposed conv exactly double H and W.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_size),
            nn.ReLU(),
            nn.Unflatten(1, (feat_ch, feat_side, feat_side)),
            nn.ConvTranspose2d(feat_ch, hidden_ch, kernel_size=3, stride=2,
                               padding=1, output_padding=1),                    # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_ch, 3, kernel_size=3, stride=2,
                               padding=1, output_padding=1),                    # 32 -> 64
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` to its latent vector and decode it back to image space."""
        return self.decoder(self.encoder(x))

# --- Cell 2: DATA LOADING (FIXED VERSION) ---

# Sub-folder names treated as dataset splits rather than as classes.
_SPLIT_DIR_NAMES = {"train", "val", "valid", "validation", "test"}


def _load_image_folder(data_path, transform):
    """Load ``data_path`` with ImageFolder, merging split sub-folders if present.

    ImageFolder labels each image by its immediate sub-directory. When every
    (non-hidden) sub-directory of ``data_path`` is a split name such as
    ``train``/``val``, that would turn the splits into the classes. In that case
    each split is loaded on its own and the splits are concatenated, so labels
    come from the class folders inside each split.

    Returns:
        (dataset, class_names)

    Raises:
        RuntimeError: if the split folders do not share the same class list.
    """
    with os.scandir(data_path) as entries:
        subdirs = sorted(
            entry.name for entry in entries
            if entry.is_dir() and not entry.name.startswith(".")
        )

    if not subdirs or not set(subdirs) <= _SPLIT_DIR_NAMES:
        dataset = ImageFolder(root=data_path, transform=transform)
        return dataset, dataset.classes

    from torch.utils.data import ConcatDataset

    splits = [
        ImageFolder(root=os.path.join(data_path, name), transform=transform)
        for name in subdirs
    ]
    classes = splits[0].classes
    for name, split in zip(subdirs, splits):
        # Equal class lists imply equal class->index maps, so labels agree.
        if split.classes != classes:
            raise RuntimeError(
                f"Split '{name}' has classes {split.classes}, expected {classes}"
            )
    return ConcatDataset(splits), classes


def _make_demo_data(img_size):
    """Return random FakeData train/test sets with placeholder class names.

    The test set gets ``random_offset=SUBSET_SIZE``. torchvision's FakeData seeds
    each item from ``index + random_offset``, so without an offset the test
    images would be copies of the first training images.
    """
    from torchvision.datasets import FakeData

    shape = (3, img_size, img_size)
    train_dataset = FakeData(size=SUBSET_SIZE, image_size=shape, num_classes=3,
                             transform=transforms.ToTensor())
    test_dataset = FakeData(size=TEST_SIZE, image_size=shape, num_classes=3,
                            transform=transforms.ToTensor(),
                            random_offset=SUBSET_SIZE)
    class_names = ['cat', 'dog', 'wild']

    print(f"üìä –î–µ–º–æ-–¥–∞–Ω—ñ: {len(train_dataset)} —Ç—Ä–µ–Ω—É–≤–∞–ª—å–Ω–∏—Ö, {len(test_dataset)} —Ç–µ—Å—Ç–æ–≤–∏—Ö")

    return train_dataset, test_dataset, class_names


def load_tiny_dataset(data_path, img_size=64):
    """Load a random train/test sample of the image dataset at ``data_path``.

    Images are resized to ``img_size`` x ``img_size`` and converted to tensors.
    A random permutation supplies up to ``SUBSET_SIZE`` training and
    ``TEST_SIZE`` test indices. If the folder is missing or cannot be loaded,
    FakeData with three placeholder classes is returned instead.

    Args:
        data_path: root folder holding either class sub-folders, or split
            sub-folders (``train``/``val``/...) that each hold class sub-folders.
        img_size: side length of the square images.

    Returns:
        (train_dataset, test_dataset, class_names)
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])

    # Only the loading step falls back to demo data. Errors in the sampling code
    # below should surface instead of being hidden by the fallback.
    try:
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"–®–ª—è—Ö {data_path} –Ω–µ —ñ—Å–Ω—É—î")
        full_dataset, classes = _load_image_folder(data_path, transform)
    except Exception as e:
        print(f"–ü–æ–º–∏–ª–∫–∞ –∑–∞–≤–∞–Ω—Ç–∞–∂–µ–Ω–Ω—è: {e}")
        print("–°—Ç–≤–æ—Ä—é—î–º–æ –¥–µ–º–æ-–¥–∞–Ω—ñ...")
        return _make_demo_data(img_size)

    print(f" –£—Å–ø—ñ—à–Ω–æ –∑–∞–≤–∞–Ω—Ç–∞–∂–µ–Ω–æ: {len(full_dataset)} –∑–æ–±—Ä–∞–∂–µ–Ω—å")
    print(f"   –ö–ª–∞—Å–∏: {classes}")

    # Take only a small random sample for speed.
    n_total = len(full_dataset)
    train_size = min(SUBSET_SIZE, n_total)
    test_size = min(TEST_SIZE, n_total - train_size)
    if test_size == 0 and TEST_SIZE > 0 and n_total > 1:
        # Dataset is no larger than SUBSET_SIZE: move images from train to test
        # (in the configured ratio, at least one) so the test split is not empty.
        test_size = max(1, n_total * TEST_SIZE // (SUBSET_SIZE + TEST_SIZE))
        train_size = n_total - test_size

    indices = torch.randperm(n_total)[:train_size + test_size].tolist()
    train_dataset = Subset(full_dataset, indices[:train_size])
    test_dataset = Subset(full_dataset, indices[train_size:])

    print(f"üìä –í–∏–∫–æ—Ä–∏—Å—Ç–æ–≤—É—î–º–æ: {len(train_dataset)} —Ç—Ä–µ–Ω—É–≤–∞–ª—å–Ω–∏—Ö, {len(test_dataset)} —Ç–µ—Å—Ç–æ–≤–∏—Ö")

    return train_dataset, test_dataset, classes

# Load the data.
# NOTE: the recorded output of this notebook lists classes ['train', 'val'] for
# this path, i.e. it holds split folders; load_tiny_dataset now merges them.
DATA_PATH = "D:/Kotopes/Kotopes/data"
print(f" –ó–∞–≤–∞–Ω—Ç–∞–∂–µ–Ω–Ω—è –¥–∞–Ω–∏—Ö –∑: {DATA_PATH}")
train_dataset, test_dataset, class_names = load_tiny_dataset(DATA_PATH)

# --- Cell 3: TRAINING ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f" –í–∏–∫–æ—Ä–∏—Å—Ç–æ–≤—É—î—Ç—å—Å—è –ø—Ä–∏—Å—Ç—Ä—ñ–π: {device}")

model = FastAutoencoder(latent_dim=LATENT_DIM).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): train_dataset holds every class, while Cell 4 scores one class
# as "normal" and the others as anomalies. Reconstruction-error detectors are
# usually fitted on normal samples only -- confirm training on all classes is intended.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


def _mean_loss(model, loader, criterion, device):
    """Return ``criterion(model(x), x)`` averaged over the batches of ``loader``.

    Runs in eval mode without gradients; returns 0.0 for an empty loader.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for data, _ in loader:
            data = data.to(device)
            total += criterion(model(data), data).item()
    return total / max(len(loader), 1)


def fast_train(model, train_loader, test_loader, epochs=EPOCHS,
               optimizer=None, criterion=None, device=None):
    """Train ``model`` to reconstruct its inputs; return per-epoch train losses.

    Args:
        model: module whose output is compared with its input batch by ``criterion``.
        train_loader: yields ``(images, labels)``; labels are ignored.
        test_loader: evaluated after each epoch and its loss printed;
            pass ``None`` to skip evaluation.
        epochs: number of passes over ``train_loader``.
        optimizer: defaults to ``Adam(model.parameters(), lr=1e-3)``.
        criterion: reconstruction loss; defaults to ``nn.MSELoss()``.
        device: defaults to the device of the model's parameters.

    Returns:
        list with the mean per-batch training loss of each epoch.
    """
    # Resolve everything from the arguments rather than module globals, so the
    # function trains whichever model it is given.
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
    if criterion is None:
        criterion = nn.MSELoss()
    if device is None:
        device = next(model.parameters()).device

    train_losses = []
    for epoch in range(epochs):
        model.train()
        running = 0.0
        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), data)
            loss.backward()
            optimizer.step()
            running += loss.item()

        train_loss = running / max(len(train_loader), 1)
        train_losses.append(train_loss)

        message = f'Epoch {epoch+1}/{epochs}: Loss: {train_loss:.4f}'
        if test_loader is not None:
            test_loss = _mean_loss(model, test_loader, criterion, device)
            message += f', Test loss: {test_loss:.4f}'
        print(message)

    return train_losses

print("–ü–æ—á–∞—Ç–æ–∫ —à–≤–∏–¥–∫–æ–≥–æ —Ç—Ä–µ–Ω—É–≤–∞–Ω–Ω—è...")
train_losses = fast_train(model, train_loader, test_loader,
                          optimizer=optimizer, criterion=criterion, device=device)

# --- Cell 4: QUICK ANALYSIS (FIXED VERSION) ---
print("\n=== –ê–ù–ê–õ–Ü–ó ===")

# Choose the "normal" class: 'cat' when present, otherwise the first class.
NORMAL_CLASS = class_names.index('cat') if 'cat' in class_names else 0
print(f"–ù–æ—Ä–º–∞–ª—å–Ω–∏–π –∫–ª–∞—Å: {class_names[NORMAL_CLASS]}")

# Split test indices into normal / anomalous, keeping at most MAX_PER_GROUP of
# each for speed. Every test_dataset[i] builds the full image tensor, so the
# scan stops as soon as both groups are full instead of reading the whole set.
# It picks the same indices as a full scan followed by truncation.
MAX_PER_GROUP = 20
test_normal_indices = []
test_anomalous_indices = []

for i in range(len(test_dataset)):
    if (len(test_normal_indices) >= MAX_PER_GROUP
            and len(test_anomalous_indices) >= MAX_PER_GROUP):
        break
    _, label = test_dataset[i]
    group = test_normal_indices if label == NORMAL_CLASS else test_anomalous_indices
    if len(group) < MAX_PER_GROUP:
        group.append(i)

test_normal = Subset(test_dataset, test_normal_indices)
test_anomalous = Subset(test_dataset, test_anomalous_indices)

print(f" –¢–µ—Å—Ç–æ–≤—ñ –¥–∞–Ω—ñ: {len(test_normal)} –Ω–æ—Ä–º–∞–ª—å–Ω–∏—Ö, {len(test_anomalous)} –∞–Ω–æ–º–∞–ª—å–Ω–∏—Ö")

# Helper: per-image reconstruction error.
def fast_errors(model, dataset, device):
    """Return the mean squared reconstruction error of every image in ``dataset``.

    Puts ``model`` in eval mode and feeds ``dataset`` through it in batches of
    32 with gradients disabled. The result is a 1-D NumPy array with one entry
    per image, in dataset order; it is empty when ``dataset`` is empty.
    """
    model.eval()
    batch_errors = []
    with torch.no_grad():
        for batch, _labels in DataLoader(dataset, batch_size=32):
            batch = batch.to(device)
            squared = (batch - model(batch)) ** 2
            # Average over dims 1..3 -> one value per image in the batch.
            batch_errors.append(squared.mean(dim=[1, 2, 3]).cpu().numpy())
    if not batch_errors:
        return np.array([])
    return np.concatenate(batch_errors)

# Per-image reconstruction errors for each group.
normal_errors = fast_errors(model, test_normal, device)
anomalous_errors = fast_errors(model, test_anomalous, device)

if len(normal_errors) == 0:
    # np.percentile would fail with an opaque error on an empty array.
    raise RuntimeError(
        f"No test images of the normal class '{class_names[NORMAL_CLASS]}' found; "
        "cannot set an anomaly threshold."
    )

# Fit the threshold on every other normal error and measure the false-positive
# rate on the rest. Scoring FPR on the same errors the 95th percentile was taken
# from gives ~5% by construction, whatever the model does.
calibration_errors = normal_errors[0::2]
holdout_errors = normal_errors[1::2]
if len(holdout_errors) == 0:
    # A single normal image leaves nothing held out; FPR is in-sample here.
    holdout_errors = calibration_errors

threshold = np.percentile(calibration_errors, 95)
detection_rate = (np.mean(anomalous_errors > threshold) * 100
                  if len(anomalous_errors) else float('nan'))
false_positive = np.mean(holdout_errors > threshold) * 100

print("–†–ï–ó–£–õ–¨–¢–ê–¢–ò:")
print(f"   - Threshold: {threshold:.4f}")
print(f"   - –í–∏—è–≤–ª–µ–Ω–æ –∞–Ω–æ–º–∞–ª—ñ–π: {detection_rate:.1f}%")
print(f"   - –ü–æ–º–∏–ª–∫–∞ (FPR): {false_positive:.1f}%")

# --- Cell 5: VISUALIZATION ---
plt.figure(figsize=(12, 3))

# Training loss per epoch (x axis starts at 1 to match the "Epoch k/N" log).
plt.subplot(1, 3, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses)
plt.title('–í—Ç—Ä–∞—Ç–∏ –ø—Ä–∏ —Ç—Ä–µ–Ω—É–≤–∞–Ω–Ω—ñ')
plt.xlabel('–ï–ø–æ—Ö–∞')
plt.grid(True, alpha=0.3)

# Reconstruction-error histograms with the detection threshold.
plt.subplot(1, 3, 2)
plt.hist(normal_errors, alpha=0.7, label='–ù–æ—Ä–º–∞–ª—å–Ω—ñ', bins=15, color='green')
plt.hist(anomalous_errors, alpha=0.7, label='–ê–Ω–æ–º–∞–ª—å–Ω—ñ', bins=15, color='red')
plt.axvline(threshold, color='black', linestyle='--', label=f'Threshold: {threshold:.3f}')
plt.legend()
plt.title('–†–æ–∑–ø–æ–¥—ñ–ª –ø–æ–º–∏–ª–æ–∫')
plt.grid(True, alpha=0.3)

# Example: one test image (left) next to its reconstruction (right), so the
# reconstruction quality is actually visible.
plt.subplot(1, 3, 3)
sample_idx = 0
sample_img, sample_label = test_dataset[sample_idx]
sample_img = sample_img.unsqueeze(0).to(device)
model.eval()
with torch.no_grad():
    reconstructed = model(sample_img)

# Both tensors are (C, H, W); concatenating along W places them side by side.
side_by_side = torch.cat([sample_img[0], reconstructed[0]], dim=2)
plt.imshow(side_by_side.cpu().permute(1, 2, 0).numpy())
plt.title(f'–ü—Ä–∏–∫–ª–∞–¥: {class_names[sample_label]}')
plt.axis('off')

plt.tight_layout()
plt.show()

# --- Cell 6: LATENT SPACE ---
print("\n=== –õ–ê–¢–ï–ù–¢–ù–ò–ô –ü–†–û–°–¢–Ü–† ===")

def fast_latent_vectors(model, dataset, device, max_samples=50):
    """Encode the first ``max_samples`` items of ``dataset`` with ``model.encoder``.

    Switches ``model`` to eval mode and runs batches of 32 without gradients,
    stopping once at least ``max_samples`` vectors have been collected.

    Returns:
        (latent_vectors, labels): NumPy arrays cut to at most ``max_samples`` rows.
    """
    model.eval()
    codes, targets = [], []
    with torch.no_grad():
        for batch, batch_targets in DataLoader(dataset, batch_size=32):
            codes.extend(model.encoder(batch.to(device)).cpu().numpy())
            targets.extend(batch_targets.numpy())
            if len(codes) >= max_samples:
                break
    return np.array(codes)[:max_samples], np.array(targets)[:max_samples]

latent_vectors, labels = fast_latent_vectors(model, test_dataset, device)

print(f" –í–∏—Ç—è–≥–Ω—É—Ç–æ {len(latent_vectors)} –ª–∞—Ç–µ–Ω—Ç–Ω–∏—Ö –≤–µ–∫—Ç–æ—Ä—ñ–≤")

# 2-D t-SNE projection of the latent space; perplexity must stay below the
# number of samples, hence the cap.
if len(latent_vectors) > 10:
    print("üîç –ó–∞—Å—Ç–æ—Å–æ–≤—É—î–º–æ t-SNE...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(10, len(latent_vectors)-1))
    latent_2d = tsne.fit_transform(latent_vectors)

    plt.figure(figsize=(8, 4))

    # Coloured by class (colours repeat if there are more classes than entries).
    plt.subplot(1, 2, 1)
    colors = ['red', 'blue', 'green', 'orange', 'purple']
    for class_idx, class_name in enumerate(class_names):
        mask = labels == class_idx
        if mask.any():
            plt.scatter(latent_2d[mask, 0], latent_2d[mask, 1],
                        color=colors[class_idx % len(colors)],
                        label=class_name, alpha=0.7, s=30)
    plt.legend()
    plt.title('–õ–∞—Ç–µ–Ω—Ç–Ω–∏–π –ø—Ä–æ—Å—Ç—ñ—Ä –∑–∞ –∫–ª–∞—Å–∞–º–∏')
    plt.grid(True, alpha=0.3)

    # Normal class vs. everything else.
    plt.subplot(1, 2, 2)
    is_normal = labels == NORMAL_CLASS
    plt.scatter(latent_2d[is_normal, 0], latent_2d[is_normal, 1],
                c='green', label=f'–ù–æ—Ä–º–∞–ª—å–Ω—ñ ({class_names[NORMAL_CLASS]})', alpha=0.7, s=30)
    plt.scatter(latent_2d[~is_normal, 0], latent_2d[~is_normal, 1],
                c='red', label='–ê–Ω–æ–º–∞–ª—å–Ω—ñ', alpha=0.7, s=30)
    plt.legend()
    plt.title('–ü–æ–¥—ñ–ª –Ω–æ—Ä–º–∞–ª—å–Ω–∏—Ö/–∞–Ω–æ–º–∞–ª—å–Ω–∏—Ö')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print(" –ù–µ–¥–æ—Å—Ç–∞—Ç–Ω—å–æ –¥–∞–Ω–∏—Ö –¥–ª—è t-SNE")

# Summary: report the sizes actually used, which can differ from the
# configured SUBSET_SIZE / TEST_SIZE when the dataset is small.
print("\n–ï–ö–°–ü–ï–†–ò–ú–ï–ù–¢ –ó–ê–í–ï–†–®–ï–ù–û!")
print(f" –í–∏–∫–æ—Ä–∏—Å—Ç–∞–Ω–æ: {len(train_dataset)} —Ç—Ä–µ–Ω—É–≤–∞–ª—å–Ω–∏—Ö —Ç–∞ {len(test_dataset)} —Ç–µ—Å—Ç–æ–≤–∏—Ö –∑–æ–±—Ä–∞–∂–µ–Ω—å")
print(f" –ö—ñ–ª—å–∫—ñ—Å—Ç—å –µ–ø–æ—Ö: {EPOCHS}")
print(f" –ö–ª–∞—Å–∏: {class_names}")

 Завантаження даних з: D:/Kotopes/Kotopes/data
 Успішно завантажено: 16130 зображень
   Класи: ['train', 'val']
📊 Використовуємо: 14500 тренувальних, 1500 тестових
 Використовується пристрій: cpu
Початок швидкого тренування...
Epoch 1/8: Loss: 0.0341
