In [1]:
# create_module_library.py
import torch
import torchvision.models as models
import os

# Build the on-disk "module library": one state-dict checkpoint per reusable
# pre-trained submodule, all written under ./module_library/.
print("Creating the pre-trained submodule library...")
os.makedirs("module_library", exist_ok=True)

# --- CNN pieces, all taken from a single pre-trained ResNet-50 ---
resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
# The "stem" is everything in front of the first residual stage.
resnet_stem = torch.nn.Sequential(
    resnet50.conv1,
    resnet50.bn1,
    resnet50.relu,
    resnet50.maxpool,
)
for checkpoint_path, submodule in (
    ("module_library/resnet_stem.pth", resnet_stem),
    ("module_library/resnet_layer1.pth", resnet50.layer1),
    ("module_library/resnet_layer2.pth", resnet50.layer2),
):
    torch.save(submodule.state_dict(), checkpoint_path)

# --- Transformer piece: the first encoder block of a pre-trained ViT-B/16 ---
vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
transformer_block = vit.encoder.layers[0]
torch.save(transformer_block.state_dict(), "module_library/transformer_encoder_block.pth")

print("Module library created successfully!")

Creating the pre-trained submodule library...
Module library created successfully!


In [3]:
# main_modular_caft.py
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# --- 1. Global Configuration ---
# Single source of truth for the tunable knobs used by the experiment below.
config = {
    # Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    # Batch size shared by the caching pass and the classifier-head training.
    "BATCH_SIZE": 64,
    "NUM_EPOCHS": 100,
    "LEARNING_RATE": 0.001,
    # Root directory passed to torchvision's Food101 dataset. Set the
    # FOOD101_DATA_DIR environment variable to relocate the data on another
    # machine; when it is unset the original location is used unchanged.
    "DATA_DIR": os.environ.get("FOOD101_DATA_DIR", "/home/kami/Documents/datasets/"),
    # Output width of the trainable classifier head.
    "NUM_CLASSES": 101,
}

# --- 2. Adapter Module (The "Glue") ---
class CNNToTransformerAdapter(nn.Module):
    """Turn a CNN feature map into a normalized token sequence.

    A 1x1 convolution re-projects the channel dimension to the transformer
    embedding width, the two spatial axes are flattened into one sequence
    axis, and LayerNorm is applied over the embedding dimension.

    Args:
        cnn_out_channels: Channel count of the incoming feature map.
        transformer_embed_dim: Embedding width of the produced tokens.
    """

    def __init__(self, cnn_out_channels, transformer_embed_dim):
        super().__init__()
        # 1x1 conv: changes only the channel count, never the spatial size.
        self.proj = nn.Conv2d(cnn_out_channels, transformer_embed_dim, kernel_size=1)
        # Collapses every axis from index 2 onward (H and W) into one.
        self.flatten = nn.Flatten(2)
        self.norm = nn.LayerNorm(transformer_embed_dim)

    def forward(self, x):
        """Map ``[B, C, H, W]`` features to ``[B, H*W, transformer_embed_dim]`` tokens."""
        projected = self.proj(x)               # [B, D, H, W]
        sequence = self.flatten(projected)     # [B, D, H*W]
        tokens = sequence.transpose(1, 2)      # [B, H*W, D] so LayerNorm sees D last
        return self.norm(tokens)

# --- 3. The Custom Modular Backbone to be Cached ---
class ModularBackbone(nn.Module):
    """Hybrid feature extractor assembled from module-library checkpoints.

    Data flow: ResNet-50 stem -> ResNet-50 layer1 -> ResNet-50 layer2 ->
    CNN-to-transformer adapter -> first ViT-B/16 encoder block -> average
    over the token axis. The result is one 768-dimensional vector per image.

    Args:
        library_dir: Directory that holds ``resnet_stem.pth``,
            ``resnet_layer1.pth``, ``resnet_layer2.pth`` and
            ``transformer_encoder_block.pth``. Defaults to ``"module_library"``.
    """

    def __init__(self, library_dir="module_library"):
        super().__init__()
        print("Assembling new ModularBackbone for caching...")

        # Build architecture-only skeletons (weights=None). Every submodule kept
        # below is fully overwritten by a strict load_state_dict from the module
        # library, so also fetching the hub weights would be redundant work.
        resnet50 = models.resnet50(weights=None)
        self.stem = torch.nn.Sequential(resnet50.conv1, resnet50.bn1, resnet50.relu, resnet50.maxpool)
        self.layer1 = resnet50.layer1
        self.layer2 = resnet50.layer2
        vit = models.vit_b_16(weights=None)
        self.transformer = vit.encoder.layers[0]

        # Populate each pre-trained part from its checkpoint in the library.
        for module, filename in (
            (self.stem, "resnet_stem.pth"),
            (self.layer1, "resnet_layer1.pth"),
            (self.layer2, "resnet_layer2.pth"),
            (self.transformer, "transformer_encoder_block.pth"),
        ):
            self._load_from_library(module, library_dir, filename)

        # The adapter has no checkpoint in the library: it keeps PyTorch's
        # default initialisation. 512 is the channel count fed in from layer2,
        # 768 the embedding width expected by the transformer block.
        self.adapter = CNNToTransformerAdapter(cnn_out_channels=512, transformer_embed_dim=768)
        # Final pooling layer to create a single feature vector per image
        self.pool = nn.AdaptiveAvgPool1d(1)

    @staticmethod
    def _load_from_library(module, library_dir, filename):
        """Load ``<library_dir>/<filename>`` into ``module`` (strict key match)."""
        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved from; callers move the whole backbone afterwards.
        # NOTE(review): torch.load unpickles the file. Only point library_dir at
        # trusted checkpoints, or pass weights_only=True if the installed
        # PyTorch supports it.
        state_dict = torch.load(f"{library_dir}/{filename}", map_location="cpu")
        module.load_state_dict(state_dict)

    def forward(self, x):
        """Return one pooled feature vector ``[B, D]`` per input image."""
        # The full flow of our custom feature extractor
        x = self.stem(x)         # CNN part
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.adapter(x)      # Glue: feature map -> token sequence [B, L, D]
        x = self.transformer(x)  # Transformer part
        # Pool the sequence of features into a single vector
        x = x.permute(0, 2, 1)       # [B, L, D] -> [B, D, L] for pooling
        x = self.pool(x).squeeze(2)  # [B, D, 1] -> [B, D]
        return x

# --- Main Execution ---
if __name__ == "__main__":
    # Driver for the CAFT experiment: cache frozen-backbone features once
    # (Phase 1), then train only a linear classifier head on the cached
    # vectors (Phase 2) and report timings plus final-epoch accuracy.
    device = config["DEVICE"]
    feature_dim = 768  # Width of the ModularBackbone output / classifier input.
    print(f"Using device: {device}")

    # The backbone's adapter and the classifier head are randomly initialised
    # and Phase 2 shuffles its batches, so seed the RNG to make runs
    # repeatable. An optional config["SEED"] overrides the default of 0.
    torch.manual_seed(config.get("SEED", 0))

    # --- Shared Data Loading ---
    print("Loading Food101 dataset...")
    data_transforms = transforms.Compose([
        transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = torchvision.datasets.Food101(root=config["DATA_DIR"], split='train', download=False, transform=data_transforms)
    num_train_records = len(train_dataset)

    # ========================================================================
    # --- MODULAR CAFT EXPERIMENT ---
    # ========================================================================
    print("\n\n" + "="*60)
    print("        EXPERIMENT: CAFT with a Custom Modular Backbone")
    print("="*60)

    # 1. Instantiate the custom backbone and the separate trainable classifier
    modular_backbone = ModularBackbone().to(device)
    for param in modular_backbone.parameters():  # Freeze the entire custom backbone
        param.requires_grad = False

    # The classifier head is small and trainable
    # It takes the output of our backbone (768 dims)
    classifier_head = nn.Linear(feature_dim, config["NUM_CLASSES"]).to(device)

    # 2. Phase 1: Caching with the Modular Backbone
    caching_start_time = time.time()
    print("\n--- Phase 1: Caching Activations from Modular Backbone ---")

    # shuffle=False keeps cache row i aligned with dataset sample i.
    caching_loader = DataLoader(train_dataset, batch_size=config["BATCH_SIZE"], shuffle=False, num_workers=4)
    cached_features = torch.zeros(num_train_records, feature_dim)  # Cache size matches backbone output
    # Labels are captured in the same pass. Reading them afterwards by
    # iterating the dataset would decode and transform every image a second time.
    cached_labels = torch.zeros(num_train_records, dtype=torch.long)

    modular_backbone.eval()
    with torch.no_grad():
        start_idx = 0
        for images, labels in tqdm(caching_loader, desc="Caching Progress"):
            end_idx = start_idx + images.size(0)
            # Get features from our custom backbone
            features = modular_backbone(images.to(device))
            cached_features[start_idx:end_idx] = features.cpu()
            cached_labels[start_idx:end_idx] = labels
            start_idx = end_idx
    assert start_idx == num_train_records, "caching pass did not cover every training sample"

    caching_time = time.time() - caching_start_time
    cached_features = cached_features.to(device)

    # 3. Phase 2: Accelerated Training
    print("\n--- Phase 2: Accelerated Training of Classifier Head ---")
    # Each item is (row index into the cache, label); features are gathered
    # from the on-device cache inside the training loop.
    indexed_dataset = TensorDataset(torch.arange(num_train_records), cached_labels)
    accelerated_loader = DataLoader(indexed_dataset, batch_size=config["BATCH_SIZE"], shuffle=True)

    optimizer = optim.Adam(classifier_head.parameters(), lr=config["LEARNING_RATE"])
    criterion = nn.CrossEntropyLoss()

    training_start_time = time.time()
    classifier_head.train()
    # Defined before the loop so the report still works when NUM_EPOCHS is 0.
    correct, total = 0, 0
    last_epoch = config["NUM_EPOCHS"] - 1
    for epoch in range(config["NUM_EPOCHS"]):
        pbar = tqdm(accelerated_loader, desc=f"Modular CAFT Epoch {epoch+1}/{config['NUM_EPOCHS']}")
        for indices, labels in pbar:
            indices, labels = indices.to(device), labels.to(device)
            # Get features directly from the cache
            features = cached_features[indices]

            optimizer.zero_grad()
            outputs = classifier_head(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if epoch == last_epoch:  # Accumulate stats on the final epoch
                predicted = outputs.detach().argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

    training_time = time.time() - training_start_time
    # Running accuracy over the final epoch's batches; 0.0 if nothing was trained.
    final_accuracy = 100 * correct / total if total else 0.0

    # 4. Final Report
    print("\n\n" + "="*60)
    print("--- Modular CAFT Results ---")
    print(f"Backbone: Custom Hybrid (ResNet + Adapter + Transformer)")
    print("-------------------------------------------------")
    print(f"Caching Phase Time:           {caching_time:.2f} seconds")
    print(f"Accelerated Training Time:    {training_time:.2f} seconds (for {config['NUM_EPOCHS']} epochs)")
    print(f"Total Time Elapsed:           {caching_time + training_time:.2f} seconds")
    print(f"Final Training Accuracy:      {final_accuracy:.2f}%")
    print("="*60)

Using device: cuda
Loading Food101 dataset...


        EXPERIMENT: CAFT with a Custom Modular Backbone
Assembling new ModularBackbone for caching...

--- Phase 1: Caching Activations from Modular Backbone ---


Caching Progress: 100%|██████████| 1184/1184 [01:49<00:00, 10.82it/s]



--- Phase 2: Accelerated Training of Classifier Head ---


Modular CAFT Epoch 1/100: 100%|██████████| 1184/1184 [00:00<00:00, 2432.90it/s]
Modular CAFT Epoch 2/100: 100%|██████████| 1184/1184 [00:00<00:00, 2469.62it/s]
Modular CAFT Epoch 3/100: 100%|██████████| 1184/1184 [00:00<00:00, 2315.08it/s]
Modular CAFT Epoch 4/100: 100%|██████████| 1184/1184 [00:00<00:00, 2443.44it/s]
Modular CAFT Epoch 5/100: 100%|██████████| 1184/1184 [00:00<00:00, 2210.98it/s]
Modular CAFT Epoch 6/100: 100%|██████████| 1184/1184 [00:00<00:00, 2447.76it/s]
Modular CAFT Epoch 7/100: 100%|██████████| 1184/1184 [00:00<00:00, 2375.70it/s]
Modular CAFT Epoch 8/100: 100%|██████████| 1184/1184 [00:00<00:00, 2313.89it/s]
Modular CAFT Epoch 9/100: 100%|██████████| 1184/1184 [00:00<00:00, 2137.87it/s]
Modular CAFT Epoch 10/100: 100%|██████████| 1184/1184 [00:00<00:00, 2439.09it/s]
Modular CAFT Epoch 11/100: 100%|██████████| 1184/1184 [00:00<00:00, 2309.99it/s]
Modular CAFT Epoch 12/100: 100%|██████████| 1184/1184 [00:00<00:00, 2374.87it/s]
Modular CAFT Epoch 13/100: 100%|█████



--- Modular CAFT Results ---
Backbone: Custom Hybrid (ResNet + Adapter + Transformer)
-------------------------------------------------
Caching Phase Time:           109.43 seconds
Accelerated Training Time:    50.59 seconds (for 100 epochs)
Total Time Elapsed:           160.02 seconds
Final Training Accuracy:      52.03%



