# CBAM-Powered Autoencoder for BERT Embedding Compression
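This notebook compresses high-dimensional BERT embeddings with a convolutional autoencoder whose encoder stages each end in a CBAM (Convolutional Block Attention Module) attention block. The flattened encoder output is saved as the compressed feature vector for each sample.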

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pickle
import numpy as np
from tqdm import tqdm


In [None]:
# Load precomputed BERT embeddings and labels.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
# the file; only load bert_embeddings.pkl from a trusted source.
with open("bert_embeddings.pkl", "rb") as f:
    data = pickle.load(f)
# Convert to float32 up front: the tensor below is float32 anyway, and this
# halves peak memory compared with a float64 intermediate.
X_train = np.asarray(data['train_embeddings'], dtype=np.float32)
y_train = np.array(data['y_train'])

# Reshape and transpose for CNN input: each sample becomes a (64, 768, 4)
# array whose last axis is moved to the channel position -> (N, 4, 64, 768).
# NOTE(review): assumes every sample holds exactly 64*768*4 values with the
# size-4 axis varying fastest in memory — confirm against how the pickle
# was produced.
X_train = X_train.reshape((-1, 64, 768, 4)).transpose(0, 3, 1, 2)

# Min-max scale to [0, 1]. The autoencoder's decoder ends in nn.Sigmoid, so
# its reconstructions lie in (0, 1); unscaled embedding values outside that
# range could never be matched under the MSE reconstruction loss.
# X_min / X_scale are kept so values can be mapped back:
#     original = scaled * X_scale + X_min
X_min = float(X_train.min())
X_range = float(X_train.max()) - X_min
X_scale = X_range if X_range > 0 else 1.0  # constant input: avoid divide-by-zero
X_train = (X_train - X_min) / X_scale
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)

In [None]:
# Define CBAM modules
class ChannelAttention(nn.Module):
    """Channel-attention branch of CBAM.

    Squeezes every channel to one value with global average pooling and global
    max pooling, passes both through a shared two-layer 1x1-conv bottleneck,
    and returns ``sigmoid(avg_branch + max_branch)`` of shape ``(N, C, 1, 1)``.
    The caller multiplies this gate into its feature map.

    Args:
        in_planes: Number of input channels ``C``.
        ratio: Bottleneck reduction factor; the hidden width is
            ``max(1, in_planes // ratio)``.
    """

    def __init__(self, in_planes, ratio=8):
        super().__init__()
        # Clamp to at least one channel: when in_planes < ratio, integer
        # division yields 0 and the bottleneck would be zero channels wide,
        # leaving a degenerate gate.
        hidden = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The same bottleneck (shared weights) scores both pooled summaries.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    """Spatial-attention branch of CBAM.

    Collapses the channel axis with a mean and a max, stacks the two maps,
    and convolves them into a single-channel gate of shape ``(N, 1, H, W)``
    squashed by a sigmoid.

    Args:
        kernel_size: Size of the square convolution kernel (default 7). It
            must be a positive odd integer; padding is ``kernel_size // 2``,
            so the gate keeps the input's ``H`` and ``W``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # Only odd kernels with "same" padding preserve the spatial size.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(x))

class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Applies a channel-attention gate and then a spatial-attention gate, each
    multiplied into the running feature map. Both gates broadcast over the
    input, so the output has the same shape as the input.

    Args:
        in_planes: Number of input channels.
        ratio: Reduction factor forwarded to ``ChannelAttention`` (default 8,
            the same as ``ChannelAttention``'s own default).
    """

    def __init__(self, in_planes, ratio=8):
        super().__init__()
        self.ca = ChannelAttention(in_planes, ratio=ratio)
        self.sa = SpatialAttention()

    def forward(self, x):
        # Channel gate (N, C, 1, 1) first; the spatial gate (N, 1, H, W) is then
        # computed from the channel-refined map.
        x = x * self.ca(x)
        return x * self.sa(x)

In [None]:
# Define CBAM Autoencoder
class CBAMAutoencoder(nn.Module):
    """Convolutional autoencoder with CBAM attention in the encoder.

    Encoder: four stages of ``Conv2d(k=3, stride=2, pad=1) -> BatchNorm2d ->
    ReLU -> CBAM``, widening channels 4 -> 8 -> 16 -> 32 -> 64. Each stage maps
    a spatial size ``s`` to ``ceil(s / 2)``.

    Decoder: four ``ConvTranspose2d(k=3, stride=2, pad=1, output_padding=1)``
    stages, narrowing channels 64 -> 32 -> 16 -> 8 -> 4. Each stage doubles the
    spatial size. ReLU follows the first three stages and a Sigmoid follows
    the last, so reconstructions lie in (0, 1).

    The reconstruction has the input's spatial size only when H and W are
    multiples of 16.
    """

    def __init__(self):
        super().__init__()
        widths = (4, 8, 16, 32, 64)

        # Modules are created in the same order as a hand-written Sequential,
        # so parameter initialisation and state_dict keys are identical.
        down_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            down_layers.extend([
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                CBAM(c_out),
            ])
        self.encoder = nn.Sequential(*down_layers)

        reversed_widths = widths[::-1]
        up_layers = []
        for c_in, c_out in zip(reversed_widths[:-1], reversed_widths[1:]):
            up_layers.append(
                nn.ConvTranspose2d(c_in, c_out, 3, stride=2, padding=1, output_padding=1)
            )
            up_layers.append(nn.ReLU())
        # The final stage squashes to (0, 1) instead of applying ReLU.
        up_layers[-1] = nn.Sigmoid()
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the input batch ``x``."""
        latent = self.encoder(x)
        return self.decoder(latent), latent

In [None]:
# Train the autoencoder to reconstruct its own input: the dataset pairs each
# sample with itself, so inputs double as targets.
NUM_EPOCHS = 10
BATCH_SIZE = 8
LEARNING_RATE = 0.001

model = CBAMAutoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loader = DataLoader(
    TensorDataset(X_train_tensor, X_train_tensor),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    samples_seen = 0
    for x_batch, y_batch in loader:
        optimizer.zero_grad()
        decoded, _ = model(x_batch)
        loss = criterion(decoded, y_batch)
        loss.backward()
        optimizer.step()
        # MSELoss (reduction="mean") returns a per-batch average. Weighting it
        # by batch size makes the epoch figure a true per-sample mean, so a
        # shorter final batch does not skew it.
        batch_size = x_batch.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
    # max(..., 1) guards against an empty dataset (no batches).
    print(f"Epoch {epoch+1}, Loss: {running_loss / max(samples_seen, 1):.4f}")

In [None]:
# Extract latent features for every training sample. Batches are fed in the
# original order, so the whole dataset never passes through the network in
# one forward call (that could exhaust memory on large datasets).
EXTRACT_BATCH_SIZE = 64

# eval(): BatchNorm uses its running statistics, so each sample's encoding
# does not depend on which other samples share its batch.
model.eval()
latent_chunks = []
with torch.no_grad():
    for start in range(0, X_train_tensor.size(0), EXTRACT_BATCH_SIZE):
        x_batch = X_train_tensor[start:start + EXTRACT_BATCH_SIZE]
        _, latent_batch = model(x_batch)
        # Flatten each (C, h, w) latent map into one row vector; move it to
        # the CPU so .numpy() works even if the model was placed on a GPU.
        latent_chunks.append(latent_batch.flatten(start_dim=1).cpu())
compressed_features = torch.cat(latent_chunks, dim=0).numpy()

# Save compressed features
with open("cbam_compressed_features.pkl", "wb") as f:
    pickle.dump({"compressed_features": compressed_features, "y_train": y_train}, f)