# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torchvision.models import resnet18
from sklearn.ensemble import IsolationForest
from scipy.spatial.distance import mahalanobis
from sklearn.preprocessing import StandardScaler

# Define a feature extractor using a pretrained model
class FeatureExtractor(nn.Module):
    """Wrap a backbone with its final child module removed, flattening the output.

    All direct children of ``pretrained_model`` except the last one are chained
    into an ``nn.Sequential``. The forward pass runs that chain and reshapes
    the result to ``(batch, -1)``. For a torchvision ResNet the dropped child
    is the final ``fc`` classifier, so the output is the pooled feature vector.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        children = list(pretrained_model.children())
        # Keep everything but the last child (the classification head).
        self.features = nn.Sequential(*children[: len(children) - 1])

    def forward(self, x):
        batch_size = x.size(0)
        return self.features(x).view(batch_size, -1)

# Define the Auto-encoder with a feature extractor
class PretrainedAutoEncoder(nn.Module):
    """Auto-encoder pairing a supplied encoder with an MLP image decoder.

    Args:
        feature_extractor: module mapping an image batch to a
            ``(batch, feature_dim)`` feature matrix; used as-is as the encoder.
        feature_dim: width of the encoder output. The default 512 matches the
            pooled output of a ResNet-18 backbone.
        image_shape: ``(channels, height, width)`` of the reconstructed images.
            Defaults to CIFAR-10's ``(3, 32, 32)``.

    ``forward(x)`` returns ``(decoded, encoded)``. ``decoded`` has shape
    ``(batch, *image_shape)``. Because the decoder ends in ``Sigmoid``, its
    values lie in (0, 1), so reconstruction targets should be scaled to [0, 1].
    """

    def __init__(self, feature_extractor, feature_dim=512, image_shape=(3, 32, 32)):
        super().__init__()
        channels, height, width = image_shape
        # Plain tuple attribute (not a buffer), so state_dict keys are unchanged.
        self.image_shape = (channels, height, width)
        self.encoder = feature_extractor
        self.decoder = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, channels * height * width),
            nn.Sigmoid(),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded.view(x.size(0), *self.image_shape), encoded

# Load CIFAR-10 dataset
# No Normalize step: the decoder ends in Sigmoid, so its outputs lie in (0, 1)
# and the reconstruction target must be in [0, 1] too. Normalize((0.5,)*3,
# (0.5,)*3) would map pixels to [-1, 1], a range the decoder can never reach.
# Keeping pixels in [0, 1] also makes dataset images directly displayable.
#
# Random augmentation is applied only to the copy used for training.
# `cifar10_train` is the dataset indexed when reporting anomalies, so it uses a
# deterministic transform: cifar10_train[i] always returns the same image.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])
eval_transform = transforms.ToTensor()
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=eval_transform)
cifar10_train_augmented = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(cifar10_train_augmented, batch_size=256, shuffle=True)

# Model, loss and optimizer setup.
# The encoder is a pretrained ResNet-18 with its last child module removed
# (see FeatureExtractor). The decoder is freshly initialised. The optimizer
# receives every autoencoder parameter, so the backbone is fine-tuned
# together with the decoder.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of `weights=...` — confirm against the installed version.
criterion = nn.MSELoss()
pretrained_model = resnet18(pretrained=True)
feature_extractor = FeatureExtractor(pretrained_model).cuda()
autoencoder = PretrainedAutoEncoder(feature_extractor).cuda()
optimizer = optim.Adam(params=autoencoder.parameters(), lr=1e-3)

# Training loop
# The decoder learns to reconstruct each input batch. Because the optimizer
# holds all of autoencoder.parameters(), the pretrained encoder is fine-tuned
# as well. The printed loss is the sample-weighted mean over the whole epoch,
# not the loss of the final (possibly smaller, noisy) mini-batch.
autoencoder.train()
for epoch in range(50):
    epoch_loss_sum = torch.zeros((), device='cuda')
    epoch_samples = 0
    for img, _ in train_loader:
        img = img.cuda()
        optimizer.zero_grad()
        decoded, encoded = autoencoder(img)
        loss = criterion(decoded, img)
        loss.backward()
        optimizer.step()
        # Accumulate on-device to avoid a host/device sync on every step.
        epoch_loss_sum += loss.detach() * img.size(0)
        epoch_samples += img.size(0)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss_sum.item() / max(epoch_samples, 1)}')

# Extract features and reconstruction errors
# eval() makes BatchNorm use its running statistics and stop updating them, so
# an image's score does not depend on which other images share its batch.
# A dedicated non-shuffled loader keeps row k of `features` /
# `reconstruction_errors` aligned with `cifar10_train[k]`, which is how ranked
# indices are mapped back to images when the anomalies are collected. Scores
# are only fully reproducible if cifar10_train's transform is deterministic.
autoencoder.eval()
scoring_loader = DataLoader(cifar10_train, batch_size=256, shuffle=False)
features = []
reconstruction_errors = []
with torch.no_grad():
    for img, _ in scoring_loader:
        img = img.cuda()
        decoded, encoded = autoencoder(img)
        # Per-image mean squared error over channels, height and width.
        loss = torch.mean((decoded - img) ** 2, dim=[1, 2, 3])
        reconstruction_errors.extend(loss.cpu().numpy())
        features.extend(encoded.cpu().numpy())

features = np.vstack(features)

# Mahalanobis Distance
def calculate_mahalanobis_distances(features):
    """Return each row's Mahalanobis distance from the sample mean.

    Args:
        features: array-like of shape (n_samples, n_features).

    Returns:
        1-D float64 array of length n_samples.

    The inverse covariance is computed with a symmetric pseudo-inverse. With
    high-dimensional post-ReLU features, some dimensions can be constant
    (e.g. always zero), which makes the covariance singular; ``np.linalg.inv``
    then raises or returns numerically meaningless values. For a
    well-conditioned covariance, ``pinv`` agrees with ``inv``.
    """
    data = np.asarray(features, dtype=np.float64)
    mean_vec = data.mean(axis=0)
    cov_matrix = np.cov(data, rowvar=False)
    # Covariance is symmetric PSD, so the cheaper eigh-based path is valid.
    inv_cov_matrix = np.linalg.pinv(cov_matrix, hermitian=True)
    centered = data - mean_vec
    # Row-wise quadratic form d_i^T S^+ d_i, vectorised instead of a Python loop.
    squared = np.einsum('ij,jk,ik->i', centered, inv_cov_matrix, centered)
    # Clamp tiny negatives caused by floating-point rounding before the sqrt.
    return np.sqrt(np.maximum(squared, 0.0))

mahalanobis_distances = calculate_mahalanobis_distances(features)

# Isolation Forest
# A fixed random_state makes the random tree construction, and therefore the
# flagged indices, reproducible from run to run.
iso_forest = IsolationForest(contamination=0.01, random_state=0)
iso_labels = iso_forest.fit_predict(features)
# fit_predict labels outliers -1 and inliers 1.
anomaly_indices_iso_forest = np.where(iso_labels == -1)[0]

# Combine and rank anomalies
# The two scores live on very different scales: per-image MSE of pixel values
# is bounded by a few units, while Mahalanobis distances over hundreds of
# feature dimensions are typically much larger. A raw sum would therefore be
# dominated by the Mahalanobis term. Each score is standardised to zero mean
# and unit variance before summing so that both contribute.
# NOTE(review): the Isolation Forest result (anomaly_indices_iso_forest) is
# not folded into this ranking — confirm whether it should be.
score_matrix = np.column_stack([
    np.asarray(reconstruction_errors, dtype=np.float64),
    np.asarray(mahalanobis_distances, dtype=np.float64),
])
anomaly_scores = StandardScaler().fit_transform(score_matrix).sum(axis=1)
# 100 highest combined scores, most anomalous first.
top_anomaly_indices = np.argsort(anomaly_scores)[::-1][:100]

anomalous_images = [cifar10_train[i][0].numpy().transpose(1, 2, 0) for i in top_anomaly_indices]

# Visualize anomalies
# Each image is min-max rescaled into [0, 1] for display. A normalising
# dataset transform can produce negative values, and casting those to uint8
# after scaling by 255 would wrap around and corrupt the picture.
fig, axes = plt.subplots(10, 10, figsize=(10, 10))
for ax in axes.flat:
    ax.axis('off')
# zip() stops at the shorter sequence, so having fewer than 100 images leaves
# the remaining cells blank instead of raising IndexError.
for ax, image in zip(axes.flat, anomalous_images):
    low, high = float(image.min()), float(image.max())
    if high > low:
        image = (image - low) / (high - low)
    else:
        image = np.clip(image, 0.0, 1.0)
    ax.imshow(image)
plt.show()
