In [8]:
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- Imports (repeated from above; the VAE architecture is defined in a later cell) ---
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2
import numpy as np
import os
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# --- Dataset Class ---
class DrawingDataset(Dataset):
    """Grayscale drawing frames stored as ``<data_dir>/ep_*/*.png``.

    Every item is read with OpenCV as a single-channel image, resized to
    64x64 and turned into a tensor by ``transforms.ToTensor()``.
    """

    def __init__(self, data_dir):
        """Index all PNG files found in the ``ep_*`` sub-folders of ``data_dir``.

        Args:
            data_dir: Root directory containing the ``ep_*`` episode folders.
        """
        self.img_paths = self._find_images(data_dir)

        # ndarray -> PIL image -> 64x64 -> tensor
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _find_images(data_dir):
        """Return the ``ep_*/*.png`` paths under ``data_dir`` in sorted order."""
        # glob yields entries in arbitrary, filesystem-dependent order; sorting
        # keeps the index -> file mapping stable between runs and machines.
        return sorted(glob.glob(os.path.join(data_dir, "ep_*", "*.png")))

    def __len__(self):
        """Number of indexed images (used by the DataLoader)."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` in grayscale and return it transformed.

        Raises:
            OSError: If OpenCV cannot decode the file at the indexed path.
        """
        img_path = self.img_paths[idx]
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside the transform.
        if img is None:
            raise OSError(f"Could not read image: {img_path}")

        return self.transform(img)

In [None]:
class VAE64(nn.Module):
    """Convolutional variational autoencoder for 1-channel 64x64 images.

    The encoder halves the spatial size three times (64 -> 32 -> 16 -> 8) and
    projects the flattened features to a 512-d vector, from which the mean and
    log-variance of the latent Gaussian are predicted. The decoder mirrors that
    path (1x1 -> 8 -> 16 -> 32 -> 64) and ends with a sigmoid.
    """

    def __init__(self, latent_dim=128):
        """Build encoder, latent heads and decoder.

        Args:
            latent_dim: Size of the latent vector ``z``.
        """
        super().__init__()

        # Encoder: three stride-2 conv stages followed by a dense projection.
        down_stages = []
        for in_ch, out_ch in ((1, 32), (32, 64), (64, 128)):
            down_stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.GELU(),
            ])
        self.encoder = nn.Sequential(
            *down_stages,
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 512),
            nn.GELU(),
        )

        # Two linear heads parameterise the approximate posterior.
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)

        # Decoder: lift z to 512 features, grow 1x1 -> 8x8, then double twice
        # with BatchNorm/GELU and once more into the final sigmoid output.
        self.decoder_input = nn.Linear(latent_dim, 512)
        up_stages = [
            nn.Unflatten(1, (512, 1, 1)),
            nn.ConvTranspose2d(512, 128, kernel_size=8, stride=1, padding=0),  # 8x8
            nn.BatchNorm2d(128),
            nn.GELU(),
        ]
        for in_ch, out_ch in ((128, 64), (64, 32)):  # 16x16, 32x32
            up_stages.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.GELU(),
            ])
        up_stages.extend([
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),  # 64x64
            nn.Sigmoid(),
        ])
        self.decoder = nn.Sequential(*up_stages)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for ``x``."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors ``z`` back to image space."""
        return self.decoder(self.decoder_input(z))

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        reconstruction = self.decode(sample)
        return reconstruction, mu, logvar

In [None]:
# --- Adjusted Loss Function ---
def weighted_vae_loss(recon_x, x, mu, logvar, epoch,
                      line_weight=50.0, line_threshold=0.5,
                      kl_start_epoch=50, beta_step=0.000002, beta_max=0.0001):
    """Line-weighted reconstruction loss plus a slowly warmed-up KL term.

    Args:
        recon_x: Reconstructed batch, same shape as ``x``.
        x: Target batch.
        mu: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder.
        epoch: Current epoch number; drives the KL weight schedule.
        line_weight: Weight given to target pixels above ``line_threshold``
            (the drawn lines); all other pixels have weight 1.
        line_threshold: Intensity above which a target pixel counts as a line.
        kl_start_epoch: Epochs before this one use a KL weight of exactly 0.
        beta_step: Per-epoch increase of the KL weight after ``kl_start_epoch``.
        beta_max: Upper bound of the KL weight.

    Returns:
        Scalar tensor: summed weighted squared error + beta * KL divergence.
    """
    # 1. Weight mask: emphasise the (sparse) line pixels over the background.
    weight = torch.where(
        x > line_threshold,
        torch.full_like(x, line_weight),
        torch.ones_like(x),
    )

    # 2. Weighted squared error, summed over the whole batch.
    mse = torch.sum(weight * (recon_x - x) ** 2)

    # 3. KL weight: 0 during warm-up, then a linear ramp capped at beta_max.
    if epoch < kl_start_epoch:
        beta = 0
    else:
        beta = min(beta_max, (epoch - kl_start_epoch) * beta_step)

    # With beta == 0 the KL term contributes nothing; skipping it also avoids
    # 0 * inf = NaN when logvar.exp() overflows while logvar is unregularised.
    if beta == 0:
        return mse

    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + (beta * kld)

In [None]:
# --- Training Engine ---
class VAETrainer:
    """Runs optimisation epochs for a VAE and saves reconstruction snapshots."""

    def __init__(self, model, device, lr=5e-5): # Lower learning rate for precision
        """Move ``model`` to ``device`` and create an Adam optimizer for it.

        Args:
            model: Module whose forward returns ``(reconstruction, mu, logvar)``.
            device: Device used for the model and for every batch.
            lr: Adam learning rate.
        """
        self.model = model.to(device)
        self.device = device
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def train_epoch(self, loader, epoch):
        """Train for one pass over ``loader`` and return the mean loss per sample.

        Args:
            loader: Iterable of input batches (tensors).
            epoch: Epoch number; shown in the progress bar and forwarded to
                ``weighted_vae_loss`` for its KL schedule.

        Returns:
            Sum of the batch losses divided by the number of samples processed.

        Raises:
            ValueError: If ``loader`` yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        samples_seen = 0
        for batch in tqdm(loader, desc=f"Epoch {epoch}"):
            batch = batch.to(self.device)
            self.optimizer.zero_grad()

            recon_batch, mu, logvar = self.model(batch)
            loss = weighted_vae_loss(recon_batch, batch, mu, logvar, epoch)

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            # Count what was actually processed: len(loader.dataset) is wrong
            # with drop_last or a custom sampler.
            samples_seen += batch.size(0)

        if samples_seen == 0:
            raise ValueError("train_epoch received an empty loader; nothing to train on")
        return total_loss / samples_seen

    def visualize_results(self, epoch, test_batch):
        """Save the first target of ``test_batch`` next to its reconstruction.

        Writes ``epoch_<epoch>_recon.png`` to the current working directory.
        The model is switched to eval mode and left there; ``train_epoch``
        switches it back to train mode at its start.
        """
        self.model.eval()
        with torch.no_grad():
            test_batch = test_batch.to(self.device)
            recon, _, _ = self.model(test_batch)
            orig = test_batch[0].cpu().squeeze().numpy()
            rec = recon[0].cpu().squeeze().numpy()

        combined = np.hstack((orig, rec))

        # A dedicated figure avoids drawing on whatever pyplot figure is
        # current, and try/finally closes it even when saving fails.
        fig, ax = plt.subplots()
        try:
            ax.imshow(combined, cmap='gray')
            ax.set_title(f"Epoch {epoch}: Target vs Prediction")
            ax.axis('off')
            fig.savefig(f"epoch_{epoch}_recon.png")
        finally:
            plt.close(fig)

In [None]:

# --- Execution ---
if __name__ == "__main__":
    # --- Configuration ---
    LATENT_DIM = 128
    BATCH_SIZE = 64 # Increased batch size for more stable KLD
    EPOCHS = 100    # Accurate VAEs need more epochs for high resolution
    DATA_DIR = "drawing_data"
    SEED = 42
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed weight init, shuffling and latent sampling so that a
    # Restart & Run All follows the same random sequence.
    torch.manual_seed(SEED)

    dataset = DrawingDataset(DATA_DIR)
    if len(dataset) == 0:
        # Fail with the actual cause instead of an opaque sampler error.
        raise FileNotFoundError(
            f"No PNG frames found under {DATA_DIR!r} (expected ep_*/*.png)"
        )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = VAE64(latent_dim=LATENT_DIM)
    trainer = VAETrainer(model, DEVICE)

    # One fixed batch for every snapshot, so the saved reconstructions show
    # the same drawing at each epoch and can be compared with each other.
    preview_batch = next(iter(loader))

    for epoch in range(1, EPOCHS + 1):
        avg_loss = trainer.train_epoch(loader, epoch)
        print(f"Loss: {avg_loss:.4f}")

        if epoch % 5 == 0:
            trainer.visualize_results(epoch, preview_batch)

    torch.save(model.state_dict(), "vae_vla_base.pth")
    print("Accurate Model Saved.")

Epoch 1:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 1: 100%|██████████| 13/13 [00:05<00:00,  2.17it/s]


Loss: 914.2000


Epoch 2: 100%|██████████| 13/13 [00:05<00:00,  2.18it/s]


Loss: 802.4114


Epoch 3: 100%|██████████| 13/13 [00:05<00:00,  2.28it/s]


Loss: 692.3802


Epoch 4: 100%|██████████| 13/13 [00:05<00:00,  2.50it/s]


Loss: 581.2556


Epoch 5: 100%|██████████| 13/13 [00:05<00:00,  2.41it/s]


Loss: 503.4939


Epoch 6: 100%|██████████| 13/13 [00:05<00:00,  2.41it/s]


Loss: 453.0713


Epoch 7: 100%|██████████| 13/13 [00:05<00:00,  2.50it/s]


Loss: 415.0368


Epoch 8: 100%|██████████| 13/13 [00:05<00:00,  2.42it/s]


Loss: 383.2708


Epoch 9: 100%|██████████| 13/13 [00:05<00:00,  2.28it/s]


Loss: 356.3653


Epoch 10: 100%|██████████| 13/13 [00:06<00:00,  2.16it/s]


Loss: 332.8153


Epoch 11: 100%|██████████| 13/13 [00:05<00:00,  2.41it/s]


Loss: 312.9192


Epoch 12: 100%|██████████| 13/13 [00:05<00:00,  2.27it/s]


Loss: 294.6058


Epoch 13: 100%|██████████| 13/13 [00:05<00:00,  2.29it/s]


Loss: 278.9884


Epoch 14: 100%|██████████| 13/13 [00:05<00:00,  2.48it/s]


Loss: 265.3638


Epoch 15: 100%|██████████| 13/13 [00:05<00:00,  2.27it/s]


Loss: 252.5307


Epoch 16: 100%|██████████| 13/13 [00:05<00:00,  2.42it/s]


Loss: 241.2756


Epoch 17: 100%|██████████| 13/13 [00:05<00:00,  2.48it/s]


Loss: 230.8385


Epoch 18: 100%|██████████| 13/13 [00:05<00:00,  2.31it/s]


Loss: 221.1671


Epoch 19: 100%|██████████| 13/13 [00:06<00:00,  2.09it/s]


Loss: 212.4659


Epoch 20: 100%|██████████| 13/13 [00:05<00:00,  2.22it/s]


Loss: 203.8417


Epoch 21: 100%|██████████| 13/13 [00:05<00:00,  2.39it/s]


Loss: 196.0178


Epoch 22: 100%|██████████| 13/13 [00:05<00:00,  2.54it/s]


Loss: 188.4332


Epoch 23: 100%|██████████| 13/13 [00:05<00:00,  2.32it/s]


Loss: 181.4138


Epoch 24: 100%|██████████| 13/13 [00:05<00:00,  2.30it/s]


Loss: 175.1088


Epoch 25: 100%|██████████| 13/13 [00:04<00:00,  2.62it/s]


Loss: 169.0128


Epoch 26: 100%|██████████| 13/13 [00:05<00:00,  2.49it/s]


Loss: 163.2348


Epoch 27: 100%|██████████| 13/13 [00:05<00:00,  2.37it/s]


Loss: 158.6077


Epoch 28: 100%|██████████| 13/13 [00:06<00:00,  2.05it/s]


Loss: 152.9357


Epoch 29: 100%|██████████| 13/13 [00:06<00:00,  2.09it/s]


Loss: 147.8609


Epoch 30: 100%|██████████| 13/13 [00:05<00:00,  2.19it/s]


Loss: 143.3774


Epoch 31: 100%|██████████| 13/13 [00:05<00:00,  2.38it/s]


Loss: 139.2126


Epoch 32: 100%|██████████| 13/13 [00:04<00:00,  2.65it/s]


Loss: 135.4408


Epoch 33: 100%|██████████| 13/13 [00:05<00:00,  2.39it/s]


Loss: 132.3613


Epoch 34: 100%|██████████| 13/13 [00:04<00:00,  2.73it/s]


Loss: 129.1459


Epoch 35: 100%|██████████| 13/13 [00:04<00:00,  2.77it/s]


Loss: 124.9149


Epoch 36: 100%|██████████| 13/13 [00:04<00:00,  2.84it/s]


Loss: 121.4096


Epoch 37: 100%|██████████| 13/13 [00:04<00:00,  2.79it/s]


Loss: 118.4301


Epoch 38: 100%|██████████| 13/13 [00:04<00:00,  2.72it/s]


Loss: 115.4670


Epoch 39: 100%|██████████| 13/13 [00:04<00:00,  2.72it/s]


Loss: 112.6502


Epoch 40: 100%|██████████| 13/13 [00:04<00:00,  2.73it/s]


Loss: 109.4658


Epoch 41: 100%|██████████| 13/13 [00:04<00:00,  2.76it/s]


Loss: 106.8719


Epoch 42: 100%|██████████| 13/13 [00:04<00:00,  2.72it/s]


Loss: 104.4658


Epoch 43: 100%|██████████| 13/13 [00:04<00:00,  2.77it/s]


Loss: 102.0863


Epoch 44: 100%|██████████| 13/13 [00:04<00:00,  2.83it/s]


Loss: 99.7697


Epoch 45: 100%|██████████| 13/13 [00:04<00:00,  2.83it/s]


Loss: 97.3372


Epoch 46: 100%|██████████| 13/13 [00:04<00:00,  2.63it/s]


Loss: 94.9930


Epoch 47: 100%|██████████| 13/13 [00:04<00:00,  2.65it/s]


Loss: 92.6802


Epoch 48: 100%|██████████| 13/13 [00:04<00:00,  2.66it/s]


Loss: 90.9135


Epoch 49: 100%|██████████| 13/13 [00:04<00:00,  2.76it/s]


Loss: 88.8131


Epoch 50: 100%|██████████| 13/13 [00:04<00:00,  2.77it/s]


Loss: 86.6212


Epoch 51: 100%|██████████| 13/13 [00:04<00:00,  2.75it/s]


Loss: 84.8858


Epoch 52: 100%|██████████| 13/13 [00:04<00:00,  2.75it/s]


Loss: 83.2059


Epoch 53: 100%|██████████| 13/13 [00:04<00:00,  2.77it/s]


Loss: 80.9173


Epoch 54: 100%|██████████| 13/13 [00:04<00:00,  2.75it/s]


Loss: 79.0675


Epoch 55: 100%|██████████| 13/13 [00:04<00:00,  2.68it/s]


Loss: 77.2352


Epoch 56: 100%|██████████| 13/13 [00:04<00:00,  2.77it/s]


Loss: 75.5444


Epoch 57: 100%|██████████| 13/13 [00:04<00:00,  2.85it/s]


Loss: 73.8825


Epoch 58: 100%|██████████| 13/13 [00:04<00:00,  2.84it/s]


Loss: 72.2578


Epoch 59: 100%|██████████| 13/13 [00:04<00:00,  2.72it/s]


Loss: 70.4959


Epoch 60: 100%|██████████| 13/13 [00:04<00:00,  2.72it/s]


Loss: 69.0013


Epoch 61: 100%|██████████| 13/13 [00:04<00:00,  2.82it/s]


Loss: 67.6146


Epoch 62: 100%|██████████| 13/13 [00:04<00:00,  2.78it/s]


Loss: 66.0399


Epoch 63: 100%|██████████| 13/13 [00:04<00:00,  2.68it/s]


Loss: 64.7115


Epoch 64: 100%|██████████| 13/13 [00:04<00:00,  2.71it/s]


Loss: 63.5638


Epoch 65: 100%|██████████| 13/13 [00:04<00:00,  2.72it/s]


Loss: 62.1958


Epoch 66: 100%|██████████| 13/13 [00:04<00:00,  2.70it/s]


Loss: 61.1333


Epoch 67: 100%|██████████| 13/13 [00:04<00:00,  2.74it/s]


Loss: 59.7076


Epoch 68: 100%|██████████| 13/13 [00:04<00:00,  2.69it/s]


Loss: 58.5021


Epoch 69: 100%|██████████| 13/13 [00:04<00:00,  2.65it/s]


Loss: 57.3331


Epoch 70: 100%|██████████| 13/13 [00:04<00:00,  2.85it/s]


Loss: 56.1498


Epoch 71: 100%|██████████| 13/13 [00:05<00:00,  2.57it/s]


Loss: 55.1987


Epoch 72: 100%|██████████| 13/13 [00:04<00:00,  2.79it/s]


Loss: 54.3519


Epoch 73: 100%|██████████| 13/13 [00:04<00:00,  2.69it/s]


Loss: 53.0986


Epoch 74: 100%|██████████| 13/13 [00:04<00:00,  2.81it/s]


Loss: 51.9272


Epoch 75: 100%|██████████| 13/13 [00:04<00:00,  2.76it/s]


Loss: 50.7999


Epoch 76: 100%|██████████| 13/13 [00:04<00:00,  2.77it/s]


Loss: 49.6789


Epoch 77: 100%|██████████| 13/13 [00:04<00:00,  2.73it/s]


Loss: 48.8512


Epoch 78: 100%|██████████| 13/13 [00:04<00:00,  2.80it/s]


Loss: 47.9269


Epoch 79: 100%|██████████| 13/13 [00:04<00:00,  2.71it/s]


Loss: 47.1016


Epoch 80: 100%|██████████| 13/13 [00:04<00:00,  2.71it/s]


Loss: 46.3180


Epoch 81: 100%|██████████| 13/13 [00:04<00:00,  2.70it/s]


Loss: 45.3794


Epoch 82: 100%|██████████| 13/13 [00:04<00:00,  2.85it/s]


Loss: 44.9116


Epoch 83: 100%|██████████| 13/13 [00:04<00:00,  2.86it/s]


Loss: 44.2618


Epoch 84: 100%|██████████| 13/13 [00:04<00:00,  2.62it/s]


Loss: 43.3557


Epoch 85: 100%|██████████| 13/13 [00:05<00:00,  2.58it/s]


Loss: 42.3370


Epoch 86: 100%|██████████| 13/13 [00:04<00:00,  2.64it/s]


Loss: 41.5683


Epoch 87: 100%|██████████| 13/13 [00:04<00:00,  2.77it/s]


Loss: 40.8893


Epoch 88: 100%|██████████| 13/13 [00:04<00:00,  2.69it/s]


Loss: 40.2781


Epoch 89: 100%|██████████| 13/13 [00:04<00:00,  2.62it/s]


Loss: 39.4050


Epoch 90: 100%|██████████| 13/13 [00:04<00:00,  2.81it/s]


Loss: 38.7324


Epoch 91: 100%|██████████| 13/13 [00:04<00:00,  2.76it/s]


Loss: 38.0343


Epoch 92: 100%|██████████| 13/13 [00:04<00:00,  2.70it/s]


Loss: 37.4116


Epoch 93: 100%|██████████| 13/13 [00:04<00:00,  2.78it/s]


Loss: 36.8398


Epoch 94: 100%|██████████| 13/13 [00:04<00:00,  2.78it/s]


Loss: 36.2394


Epoch 95: 100%|██████████| 13/13 [00:04<00:00,  2.71it/s]


Loss: 35.6184


Epoch 96: 100%|██████████| 13/13 [00:04<00:00,  2.79it/s]


Loss: 35.0819


Epoch 97: 100%|██████████| 13/13 [00:04<00:00,  2.73it/s]


Loss: 34.4583


Epoch 98: 100%|██████████| 13/13 [00:04<00:00,  2.75it/s]


Loss: 34.0588


Epoch 99: 100%|██████████| 13/13 [00:04<00:00,  2.75it/s]


Loss: 33.4995


Epoch 100: 100%|██████████| 13/13 [00:04<00:00,  2.74it/s]


Loss: 32.3732
Accurate Model Saved.
