In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import time
import random
import copy
import json
import gc

# --- 1. Global Configuration ---
config = {
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "BATCH_SIZE": 128,
    "DATA_DIR": "/home/kami/Documents/datasets/",
    "NUM_CLASSES": 101,
    "EVOLUTION_POPULATION_SIZE": 10,
    "EVOLUTION_GENERATIONS": 5,
    "EVOLUTION_NUM_PARENTS": 3,
    "EVAL_EPOCHS": 1,
}

# --- 2. Modular AI Components ---

# ================================================================= #
# --- CORRECTED HELPER MODULES ---
class Permute(nn.Module):
    def __init__(self, *dims):
        super().__init__() # CRITICAL: Call parent constructor
        self.dims = dims
    def forward(self, x):
        return x.permute(*self.dims)

class Squeeze(nn.Module):
    def __init__(self, dim):
        super().__init__() # CRITICAL: Call parent constructor
        self.dim = dim
    def forward(self, x):
        return x.squeeze(self.dim)
# ================================================================= #


class CNNToTransformerAdapter(nn.Module):
    def __init__(self, cnn_out_channels, transformer_embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(cnn_out_channels, transformer_embed_dim, kernel_size=1)
        self.flatten = nn.Flatten(2)
        self.norm = nn.LayerNorm(transformer_embed_dim)
    def forward(self, x):
        x = self.proj(x); x = self.flatten(x); x = x.permute(0, 2, 1); x = self.norm(x)
        return x

class AssembledBackbone(nn.Module):
    def __init__(self, genome, module_library):
        super().__init__()
        self.genome = genome
        self.module_list = nn.ModuleList()
        self.output_channels = 0

        current_channels = 3
        current_type = 'image'

        for module_name in genome:
            module_info = module_library[module_name]
            if module_info['input_type'] == 'sequence' and current_type == 'image':
                adapter = CNNToTransformerAdapter(current_channels, module_info['input_channels'])
                self.module_list.append(adapter)

            module_instance = copy.deepcopy(module_info['module'])
            self.module_list.append(module_instance)
            current_channels = module_info['output_channels']
            current_type = module_info['output_type']

        if current_type == 'image':
            self.module_list.append(nn.AdaptiveAvgPool2d((1, 1)))
            self.module_list.append(nn.Flatten(1))
            self.output_channels = current_channels
        else:
            self.module_list.append(Permute(0, 2, 1))
            self.module_list.append(nn.AdaptiveAvgPool1d(1))
            self.module_list.append(Squeeze(-1))
            self.output_channels = current_channels

    def forward(self, x):
        for layer in self.module_list:
            x = layer(x)
        return x

# --- 3. The Search System (Evolutionary Algorithm) ---
class NeuralArchitectureSearch:
    def __init__(self, search_space, module_library, dataset):
        self.search_space = search_space
        self.module_library = module_library
        self.dataset = dataset
        self.population = []
        self.history = []

    def _generate_random_genome(self):
        genome = []
        current_type = 'image'
        for stage in self.search_space:
            possible_modules = [
                m for m in stage if
                (self.module_library[m]['input_type'] == current_type) or
                (self.module_library[m]['input_type'] == 'sequence' and current_type == 'image')
            ]
            chosen_module = random.choice(possible_modules)
            genome.append(chosen_module)
            current_type = self.module_library[chosen_module]['output_type']
        return genome

    def _mutate_genome(self, genome):
        mutation_point = random.randint(0, len(genome) - 1)
        if mutation_point == 0: input_type_needed = 'image'
        else: input_type_needed = self.module_library[genome[mutation_point - 1]]['output_type']
        stage = self.search_space[mutation_point]
        compatible_replacements = [
            m for m in stage if
            (self.module_library[m]['input_type'] == input_type_needed) or
            (self.module_library[m]['input_type'] == 'sequence' and input_type_needed == 'image')
        ]
        new_module_name = random.choice(compatible_replacements)
        new_genome = copy.deepcopy(genome)
        new_genome[mutation_point] = new_module_name
        return new_genome

    def _evaluate_genome_with_caft(self, genome):
        print(f"\nEvaluating Genome: {' -> '.join(genome)}")

        backbone = AssembledBackbone(genome, self.module_library).to(config["DEVICE"])
        for param in backbone.parameters(): param.requires_grad = False
        backbone.eval()

        output_dim = backbone.output_channels
        classifier = nn.Linear(output_dim, config["NUM_CLASSES"]).to(config["DEVICE"])
        optimizer = optim.Adam(classifier.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        caching_start_time = time.time()
        caching_loader = DataLoader(self.dataset, batch_size=config["BATCH_SIZE"], shuffle=False)
        cached_features = torch.zeros(len(self.dataset), output_dim)

        with torch.no_grad():
            start_idx = 0
            for images, _ in caching_loader:
                features = backbone(images.to(config["DEVICE"]))
                cached_features[start_idx : start_idx + images.size(0)] = features.cpu()
                start_idx += images.size(0)
        caching_time = time.time() - caching_start_time
        cached_features = cached_features.to(config["DEVICE"])

        train_labels = [label for _, label in self.dataset]
        indexed_dataset = TensorDataset(torch.arange(len(self.dataset)), torch.tensor(train_labels))
        accelerated_loader = DataLoader(indexed_dataset, batch_size=config["BATCH_SIZE"], shuffle=True)

        classifier.train()
        for _ in range(config["EVAL_EPOCHS"]):
            for indices, labels in accelerated_loader:
                features = cached_features[indices.to(config["DEVICE"])]
                optimizer.zero_grad()
                outputs = classifier(features)
                loss = criterion(outputs, labels.to(config["DEVICE"]))
                loss.backward()
                optimizer.step()

        classifier.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for indices, labels in accelerated_loader:
                features = cached_features[indices.to(config["DEVICE"])]
                outputs = classifier(features)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels.to(config["DEVICE"])).sum().item()
        accuracy = 100 * correct / total

        fitness = accuracy / (caching_time + 1)

        print(f" -> Cache Time: {caching_time:.2f}s | Accuracy: {accuracy:.2f}% | Fitness: {fitness:.2f}")

        del backbone, classifier, cached_features
        gc.collect()
        torch.cuda.empty_cache()

        return {"genome": genome, "accuracy": accuracy, "latency": caching_time, "fitness": fitness}

    def run_search(self):
        print("--- Initializing Generation 0 ---")
        for _ in range(config["EVOLUTION_POPULATION_SIZE"]):
            self.population.append(self._generate_random_genome())

        for gen in range(config["EVOLUTION_GENERATIONS"]):
            print(f"\n\n{'='*20} EVOLUTION GENERATION {gen+1}/{config['EVOLUTION_GENERATIONS']} {'='*20}")

            evaluated_population = [self._evaluate_genome_with_caft(genome) for genome in self.population]
            evaluated_population.sort(key=lambda x: x['fitness'], reverse=True)
            self.history.append(evaluated_population)

            print(f"\n--- Top Model of Generation {gen+1} ---")
            best_model = evaluated_population[0]
            print(f"  Genome: {' -> '.join(best_model['genome'])}")
            print(f"  Accuracy: {best_model['accuracy']:.2f}% | Latency: {best_model['latency']:.2f}s | Fitness: {best_model['fitness']:.2f}")

            parents = [p['genome'] for p in evaluated_population[:config["EVOLUTION_NUM_PARENTS"]]]
            next_population = parents

            while len(next_population) < config["EVOLUTION_POPULATION_SIZE"]:
                parent = random.choice(parents)
                child = self._mutate_genome(parent)
                next_population.append(child)

            self.population = next_population

        print("\n\n--- Search Complete ---")
        return self.history[-1][0]


# --- Main Execution ---
if __name__ == "__main__":
    print("--- Building Module Library (this may take a moment) ---")
    resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)

    module_library = {
        'resnet_stem': {'module': nn.Sequential(resnet50.conv1, resnet50.bn1, resnet50.relu, resnet50.maxpool),
                        'input_type': 'image', 'output_type': 'image', 'input_channels': 3, 'output_channels': 64},
        'resnet_layer1': {'module': resnet50.layer1, 'input_type': 'image', 'output_type': 'image', 'input_channels': 64, 'output_channels': 256},
        'resnet_layer2': {'module': resnet50.layer2, 'input_type': 'image', 'output_type': 'image', 'input_channels': 256, 'output_channels': 512},
        'resnet_layer3': {'module': resnet50.layer3, 'input_type': 'image', 'output_type': 'image', 'input_channels': 512, 'output_channels': 1024},
        'transformer_encoder': {'module': vit.encoder.layers[0], 'input_type': 'sequence', 'output_type': 'sequence', 'input_channels': 768, 'output_channels': 768},
    }

    for key in module_library:
        for param in module_library[key]['module'].parameters():
            param.requires_grad = False

    search_space = [
        ['resnet_stem'],
        ['resnet_layer1'],
        ['resnet_layer2', 'resnet_layer3'],
        ['transformer_encoder'],
    ]

    print("\n--- Loading Food101 Dataset ---")
    data_transforms = transforms.Compose([
        transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    full_train_dataset = torchvision.datasets.Food101(root=config["DATA_DIR"], split='train', download=False, transform=data_transforms)
    subset_indices = random.sample(range(len(full_train_dataset)), 5000)
    train_dataset_subset = torch.utils.data.Subset(full_train_dataset, subset_indices)
    print(f"Using a subset of {len(train_dataset_subset)} images for the search.")

    nas = NeuralArchitectureSearch(search_space, module_library, train_dataset_subset)
    best_discovered_model = nas.run_search()

    print("\n\n" + "="*60)
    print("      BEST DISCOVERED ARCHITECTURE")
    print("="*60)
    print(f"  Genome: {' -> '.join(best_discovered_model['genome'])}")
    print(f"  Final Accuracy: {best_discovered_model['accuracy']:.2f}%")
    print(f"  Backbone Latency (Caching Time): {best_discovered_model['latency']:.2f}s")
    print(f"  Final Fitness Score: {best_discovered_model['fitness']:.2f}")

--- Building Module Library (this may take a moment) ---

--- Loading Food101 Dataset ---
Using a subset of 5000 images for the search.
--- Initializing Generation 0 ---



Evaluating Genome: resnet_stem -> resnet_layer1 -> resnet_layer2 -> transformer_encoder
 -> Cache Time: 16.34s | Accuracy: 5.28% | Fitness: 0.30

Evaluating Genome: resnet_stem -> resnet_layer1 -> resnet_layer2 -> transformer_encoder
 -> Cache Time: 16.76s | Accuracy: 6.36% | Fitness: 0.36

Evaluating Genome: resnet_stem -> resnet_layer1 -> resnet_layer2 -> transformer_encoder
 -> Cache Time: 15.94s | Accuracy: 3.96% | Fitness: 0.23

Evaluating Genome: resnet_stem -> resnet_layer1 -> resnet_layer3 -> transformer_encoder


RuntimeError: Given groups=1, weight of size [256, 512, 1, 1], expected input[128, 256, 56, 56] to have 512 channels, but got 256 channels instead