# In [None]:
import os
from pathlib import Path
from PIL import Image
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

torch.manual_seed(123)  # seed torch's global RNG before anything draws from it
# ------------------------------------------------
# 1. Load Dataset
# ------------------------------------------------
train_dir = "data/EuroSAT/train"
test_dir = "data/EuroSAT/test"

# The only preprocessing step is the conversion to a tensor; no resizing
# is applied to the images.
transform = transforms.Compose([transforms.ToTensor()])

# One ImageFolder per split, both sharing the same transform.
train_dataset, test_dataset = (
    datasets.ImageFolder(root, transform=transform)
    for root in (train_dir, test_dir)
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

# ------------------------------------------------
# 2. Autoencoder Definition (16x16 bottleneck)
# ------------------------------------------------
class Autoencoder(nn.Module):
    """Small convolutional autoencoder with two stride-2 stages each way.

    The encoder applies two stride-2 convolutions (3 -> 64 -> 128 channels),
    each followed by ReLU. The decoder mirrors this with two stride-2
    transposed convolutions (128 -> 64 -> 3 channels) and finishes with a
    sigmoid. For a 64x64 input the latent map is 128 x 16 x 16 and the
    reconstruction is 3 x 64 x 64.
    """

    def __init__(self):
        super().__init__()

        # Down-sampling path: each stride-2 convolution halves height and
        # width (64 -> 32 -> 16 for a 64x64 input).
        downsample = [
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*downsample)

        # Up-sampling path: each stride-2 transposed convolution doubles
        # height and width (16 -> 32 -> 64); the sigmoid bounds the output.
        upsample = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*upsample)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the image batch ``x``."""
        latent = self.encoder(x)
        return self.decoder(latent), latent

# ------------------------------------------------
# 3. Train the Autoencoder
# ------------------------------------------------
# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model = Autoencoder().to(device)

# Reconstruction objective: mean squared error between output and input.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 150
model.train()

for epoch in range(epochs):
    # Sum of per-batch losses; divided by the batch count when reported.
    total_loss = 0
    for batch, _ in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        reconstruction, _ = model(batch)
        # The input batch is its own target; labels are ignored.
        loss = criterion(reconstruction, batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss={total_loss/len(train_loader):.4f}")

# ------------------------------------------------
# 4. Save Encoded 16×16 Representations
# ------------------------------------------------
output_root = Path("new_dataset")
# Make sure both split directories exist before anything is written.
for split_dir in (output_root / "train", output_root / "test"):
    split_dir.mkdir(parents=True, exist_ok=True)

# Switch the model to evaluation mode for the encoding pass below.
model.eval()

def save_encoded_split(dataloader, split_name, dataset):
    """Encode every batch of ``dataloader`` and save one PNG per image.

    Each image's bottleneck activation is averaged over its channels into a
    single-channel map and written to
    ``output_root/<split_name>/<class name>/<k>.png``. ``k`` starts at the
    number of entries already present in that class directory and increases
    by one per saved image, so re-running appends rather than overwrites.

    Args:
        dataloader: Iterable yielding ``(image_batch, label_batch)`` pairs.
        split_name: Sub-directory of ``output_root`` to write into.
        dataset: Object whose ``classes`` sequence maps a label index to the
            class name used as the directory name.

    Uses the module-level ``model``, ``device`` and ``output_root``.
    """
    # Next file index per class. It is seeded once from the directory
    # contents and then counted locally, instead of re-listing the directory
    # for every image. This assumes nothing else writes there meanwhile.
    next_index = {}

    # Inference only: skip building the autograd graph for each batch.
    with torch.no_grad():
        for img_tensor, labels in dataloader:
            img_tensor = img_tensor.to(device)
            # Bottleneck is (B, 128, 16, 16) for 64x64 inputs.
            _, encoded = model(img_tensor)

            # Collapse the channel axis -> (B, 1, 16, 16).
            # NOTE(review): save_image appears to clamp values to [0, 1], so
            # activations above 1 would saturate - confirm this is acceptable.
            encoded_img = encoded.mean(dim=1, keepdim=True)

            for single_map, label in zip(encoded_img, labels):
                class_name = dataset.classes[int(label)]
                out_dir = output_root / split_name / class_name

                if class_name not in next_index:
                    out_dir.mkdir(parents=True, exist_ok=True)
                    next_index[class_name] = len(os.listdir(out_dir))

                filename = f"{next_index[class_name]}.png"
                save_image(single_map, out_dir / filename)
                next_index[class_name] += 1

# Encode and write both splits, training split first.
for split_name, loader, split_dataset in (
    ("train", train_loader, train_dataset),
    ("test", test_loader, test_dataset),
):
    save_encoded_split(loader, split_name, split_dataset)

print("Done! Encoded images saved in new_dataset/")

import matplotlib.pyplot as plt
import torch

# Get one batch
images, _ = next(iter(train_loader))
images = images.to(device)

# Pass through autoencoder without tracking gradients
model.eval()
with torch.no_grad():
    recon, encoded = model(images)

# Display sample images: original, channel-averaged bottleneck, and
# reconstruction side by side. The indices assume the batch holds at least
# 35 images (the training loader above uses batch_size=64).
for idx in [15, 34, 17, 11, 0]:
    original = images[idx].cpu().permute(1, 2, 0)     # (H, W, C)
    encoded_img = encoded[idx].mean(dim=0).cpu()      # (16, 16)
    # Show the reconstruction as RGB, like the original, rather than as a
    # channel mean rendered through the default colormap.
    recon_img = recon[idx].cpu().permute(1, 2, 0)     # (H, W, C)

    # (title, image, colormap) for each of the three panels; the colormap
    # only matters for the single-channel bottleneck map.
    panels = (
        ("Original 64×64", original, None),
        ("Encoded 16×16", encoded_img, "gray"),
        ("Decoded 64×64", recon_img, None),
    )

    # Plot
    plt.figure(figsize=(8, 4))
    for position, (title, picture, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        # "gray" is the canonical colormap name; "grey" is rejected by
        # older matplotlib releases.
        plt.imshow(picture, cmap=cmap)
        plt.axis("off")

    plt.show()
