In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
from PIL import Image
import os
import cv2
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time
import json
from pathlib import Path

try:
    import wandb
except ImportError:
    !pip install wandb -q
    import wandb

In [None]:
# SECURITY: credentials must never be stored in the notebook itself.
# wandb.login() reads WANDB_API_KEY from the environment when it is set
# (export it outside the notebook, e.g. from a secrets store) and otherwise
# falls back to cached credentials or an interactive prompt.
wandb.login()

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to Gaussian posterior parameters.

    Five stride-2 convolutions (each followed by batch normalisation and ReLU)
    halve the spatial resolution five times while widening the channels
    32 -> 64 -> 128 -> 256 -> 512. The flattened feature map feeds two linear
    heads that output the mean and the log-variance, both of size ``latent_dim``.

    Args:
        latent_dim: Size of the latent vector.
        input_channels: Number of channels of the input images.
        image_size: Side length of the (square) input images.
    """

    def __init__(self, latent_dim=128, input_channels=3, image_size=256):
        super().__init__()

        # Settings shared by every down-sampling convolution.
        down = dict(kernel_size=4, stride=2, padding=1)

        self.conv1 = nn.Conv2d(input_channels, 32, **down)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, **down)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, **down)
        self.bn3 = nn.BatchNorm2d(128)

        self.conv4 = nn.Conv2d(128, 256, **down)
        self.bn4 = nn.BatchNorm2d(256)

        self.conv5 = nn.Conv2d(256, 512, **down)
        self.bn5 = nn.BatchNorm2d(512)

        # Five halvings shrink each spatial side by a factor of 2**5 = 32.
        final_size = image_size // 32
        self.flatten_size = 512 * final_size ** 2

        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_log_var = nn.Linear(self.flatten_size, latent_dim)

    def forward(self, x):
        """Return ``(mu, log_var)`` for the image batch ``x``."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))

        # Collapse channels and spatial dimensions into one feature vector per sample.
        features = x.view(x.size(0), -1)

        return self.fc_mu(features), self.fc_log_var(features)


In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder that maps latent vectors back to images.

    A linear layer expands the latent vector into a 512-channel feature map of
    side ``image_size // 32``. Five stride-2 transposed convolutions then double
    the resolution each time; the first four are followed by batch
    normalisation and ReLU, the last one by a sigmoid.

    Args:
        latent_dim: Size of the latent vector.
        output_channels: Number of channels of the generated images.
        image_size: Side length of the (square) generated images.
    """

    def __init__(self, latent_dim=128, output_channels=3, image_size=256):
        super().__init__()

        # Shape of the feature map the latent vector is projected to.
        self.init_size = image_size // 32
        self.init_channels = 512

        self.fc = nn.Linear(latent_dim, self.init_channels * self.init_size * self.init_size)

        # Settings shared by every up-sampling transposed convolution.
        up = dict(kernel_size=4, stride=2, padding=1)

        self.deconv1 = nn.ConvTranspose2d(512, 256, **up)
        self.bn1 = nn.BatchNorm2d(256)

        self.deconv2 = nn.ConvTranspose2d(256, 128, **up)
        self.bn2 = nn.BatchNorm2d(128)

        self.deconv3 = nn.ConvTranspose2d(128, 64, **up)
        self.bn3 = nn.BatchNorm2d(64)

        self.deconv4 = nn.ConvTranspose2d(64, 32, **up)
        self.bn4 = nn.BatchNorm2d(32)

        self.deconv5 = nn.ConvTranspose2d(32, output_channels, **up)

    def forward(self, z):
        """Decode the latent batch ``z`` into images with values in [0, 1]."""
        seed_map = self.fc(z)
        x = seed_map.view(seed_map.size(0), self.init_channels, self.init_size, self.init_size)

        hidden_stages = (
            (self.deconv1, self.bn1),
            (self.deconv2, self.bn2),
            (self.deconv3, self.bn3),
            (self.deconv4, self.bn4),
        )
        for deconv, norm in hidden_stages:
            x = F.relu(norm(deconv(x)))

        # The sigmoid bounds every output value to the [0, 1] range.
        return torch.sigmoid(self.deconv5(x))

In [None]:
class VAE(nn.Module):
    """Variational auto-encoder built from an ``Encoder`` and a ``Decoder``.

    Args:
        latent_dim: Size of the latent vector shared by encoder and decoder.
        input_channels: Channel count of the input images; the decoder
            produces the same number of channels.
        image_size: Side length of the (square) images.
    """

    def __init__(self, latent_dim=128, input_channels=3, image_size=256):
        super().__init__()

        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim, input_channels, image_size)
        self.decoder = Decoder(latent_dim, input_channels, image_size)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``.

        Sampling the noise separately keeps the result differentiable with
        respect to ``mu`` and ``log_var``.
        """
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x)
        reconstruction = self.decoder(self.reparameterize(mu, log_var))
        return reconstruction, mu, log_var


def vae_loss(reconstruction, target, mu, log_var, beta=1.0):
    """Compute the beta-weighted VAE objective.

    Both terms are summed (not averaged) over every element of the batch.

    Args:
        reconstruction: Decoder output.
        target: Tensor the reconstruction is compared against.
        mu: Mean of the approximate posterior.
        log_var: Log-variance of the approximate posterior.
        beta: Weight applied to the KL term.

    Returns:
        Tuple ``(total_loss, recon_loss, kl_loss)`` of scalar tensors, where
        ``total_loss = recon_loss + beta * kl_loss``.
    """
    # Summed squared error between reconstruction and target.
    recon_loss = F.mse_loss(reconstruction, target, reduction='sum')

    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I).
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_loss = -0.5 * kl_terms.sum()

    return recon_loss + beta * kl_loss, recon_loss, kl_loss

In [None]:
class MotionDataset(Dataset):
    """Flat image-folder dataset yielding ``(input, target)`` pairs.

    Every ``.png`` / ``.jpg`` / ``.jpeg`` file (case-insensitive) directly
    inside ``split_dir`` is one sample; samples are ordered by file name.

    Args:
        split_dir: Directory containing the image files.
        transform: Optional callable applied to the RGB ``PIL`` image.
        task: ``"autoencoder"`` returns the image as both input and target.
            ``"denoise"`` returns a noisy copy as input and the clean image as
            target. Any other value behaves like ``"autoencoder"``.
        noise_std: Standard deviation of the Gaussian noise added for the
            ``"denoise"`` task.

    Raises:
        AssertionError: If ``split_dir`` contains no image files.
    """

    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, split_dir, transform=None, task="autoencoder", noise_std=0.05):
        self.dir = split_dir
        self.transform = transform
        self.task = task
        self.noise_std = noise_std
        self.files = sorted(
            f for f in os.listdir(self.dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        assert self.files, f"No images in {self.dir}"

    def __getitem__(self, i):
        """Load sample ``i`` and return its ``(input, target)`` pair."""
        path = os.path.join(self.dir, self.files[i])
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)

        if self.task == "denoise":
            if not torch.is_tensor(img):
                # Noise is added with tensor ops, so the transform must have
                # produced a tensor (e.g. via ToTensor()).
                raise TypeError(
                    "task='denoise' requires a transform that returns a torch.Tensor, "
                    f"got {type(img).__name__}"
                )
            noisy = img + self.noise_std * torch.randn_like(img)
            # Keep the corrupted input inside the [0, 1] range.
            noisy = torch.clamp(noisy, 0, 1)
            return noisy, img

        # "autoencoder" and every unrecognised task reconstruct the input itself.
        return img, img

    def __len__(self):
        """Return the number of image files found in ``split_dir``."""
        return len(self.files)

In [None]:
def create_data_loaders(data_root, batch_size=32, image_size=256, num_workers=2):
    """Build the train / validation / test loaders for one dataset root.

    Images are read from ``<data_root>/train``, ``<data_root>/val`` and
    ``<data_root>/test``, resized to ``image_size`` x ``image_size`` and
    converted to tensors. Only the training loader shuffles.

    Args:
        data_root: Directory holding the ``train``, ``val`` and ``test`` folders.
        batch_size: Batch size used by all three loaders.
        image_size: Side length every image is resized to.
        num_workers: Worker processes per loader.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])

    # The training/evaluation loops copy batches with non_blocking=True, which
    # is only asynchronous when the host tensors live in pinned memory.
    pin_memory = torch.cuda.is_available()

    splits = (
        ('train', 'Train', True),
        ('val', 'Validation', False),
        ('test', 'Test', False),
    )

    loaders = []
    for folder, title, shuffle in splits:
        split_dir = os.path.join(data_root, folder)
        print(f"\n{title} set:")
        dataset = MotionDataset(split_dir, transform=transform)
        print(f"  {len(dataset)} images in {split_dir}")
        loaders.append(DataLoader(dataset, batch_size=batch_size,
                                  shuffle=shuffle, num_workers=num_workers,
                                  pin_memory=pin_memory))

    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader

In [None]:
def train_epoch(model, train_loader, optimizer, device, beta=1.0, epoch=0):
    """Run one optimisation pass over ``train_loader``.

    Batches may be ``(input, target)`` sequences, single-element sequences,
    dicts (keys ``'input'`` / ``'target'``) or bare tensors; when no target is
    present the input is reconstructed.

    Args:
        model: Module returning ``(reconstruction, mu, log_var)``.
        train_loader: Loader providing the training batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        beta: Weight of the KL term passed to ``vae_loss``.
        epoch: Zero-based epoch index, used for the progress bar and step count.

    Returns:
        Tuple ``(loss, reconstruction_loss, kl_loss)``, each the summed loss of
        the epoch divided by the number of samples in the dataset.
    """
    model.train()
    epoch_loss = 0
    epoch_recon_loss = 0
    epoch_kl_loss = 0

    # Batch metrics are only sent to W&B when a run is active, so the loop
    # also works when training without wandb.init().
    log_to_wandb = wandb.run is not None

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1} - Training')
    for batch_idx, batch in enumerate(pbar):
        if isinstance(batch, (list, tuple)):
            if len(batch) >= 2:
                input_img, target_img = batch[0], batch[1]
            else:
                input_img = target_img = batch[0]
        elif isinstance(batch, dict):
            input_img  = batch.get('input', next(iter(batch.values())))
            target_img = batch.get('target', input_img)
        else:
            input_img = target_img = batch

        input_img  = input_img.to(device, non_blocking=True)
        target_img = target_img.to(device, non_blocking=True)

        optimizer.zero_grad()
        reconstruction, mu, log_var = model(input_img)
        loss, recon_loss, kl_loss = vae_loss(reconstruction, target_img, mu, log_var, beta)

        loss.backward()
        optimizer.step()

        # Read every scalar off the device exactly once per batch.
        loss_value = loss.item()
        recon_value = recon_loss.item()
        kl_value = kl_loss.item()

        batch_size = len(input_img)
        epoch_loss       += loss_value
        epoch_recon_loss += recon_value
        epoch_kl_loss    += kl_value

        if log_to_wandb and batch_idx % 10 == 0:
            wandb.log({
                'batch/loss': loss_value / batch_size,
                'batch/reconstruction_loss': recon_value / batch_size,
                'batch/kl_loss': kl_value / batch_size,
                'batch/step': epoch * len(train_loader) + batch_idx
            })

        pbar.set_postfix({
            'loss':  f'{loss_value/batch_size:.4f}',
            'recon': f'{recon_value/batch_size:.4f}',
            'kl':    f'{kl_value/batch_size:.4f}'
        })

    # The losses are sums over samples, so dividing by the dataset size
    # yields per-sample averages.
    num_samples = len(train_loader.dataset)
    return epoch_loss/num_samples, epoch_recon_loss/num_samples, epoch_kl_loss/num_samples


In [None]:
def validate(model, val_loader, device, beta=1.0, epoch=0):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Accepts the same batch layouts as the training loop: ``(input, target)``
    sequences, single-element sequences, dicts with ``'input'`` / ``'target'``
    keys, or bare tensors.

    Returns:
        Tuple ``(loss, reconstruction_loss, kl_loss)``, each the summed loss
        over the loader divided by the number of samples in the dataset.
    """
    model.eval()

    # Running sums in the order: total, reconstruction, KL.
    totals = [0, 0, 0]

    progress = tqdm(val_loader, desc=f'Epoch {epoch+1} - Validation')
    with torch.no_grad():
        for batch in progress:
            if isinstance(batch, dict):
                input_img = batch.get('input', next(iter(batch.values())))
                target_img = batch.get('target', input_img)
            elif isinstance(batch, (list, tuple)):
                input_img = batch[0]
                # A single-element sequence is reconstructed onto itself.
                target_img = batch[1] if len(batch) >= 2 else batch[0]
            else:
                input_img = target_img = batch

            input_img = input_img.to(device, non_blocking=True)
            target_img = target_img.to(device, non_blocking=True)

            reconstruction, mu, log_var = model(input_img)
            batch_terms = vae_loss(reconstruction, target_img, mu, log_var, beta)

            totals = [running + term.item() for running, term in zip(totals, batch_terms)]

    num_samples = len(val_loader.dataset)
    val_loss, val_recon_loss, val_kl_loss = totals
    return val_loss/num_samples, val_recon_loss/num_samples, val_kl_loss/num_samples

In [None]:
def visualize_results(model, val_loader, device, epoch=0):
    """Plot inputs, reconstructions and targets for the first validation batch.

    Up to four samples are drawn as a 3-row grid (input / reconstruction /
    ground truth). The figure is logged to W&B when a run is active and shown
    inline otherwise.

    Args:
        model: Module returning ``(reconstruction, mu, log_var)``.
        val_loader: Loader whose first batch is visualised.
        device: Device the batch is moved to.
        epoch: Zero-based epoch index used in the W&B key.
    """
    model.eval()

    batch = next(iter(val_loader))
    if isinstance(batch, (list, tuple)):
        if len(batch) >= 2:
            input_batch, target_batch = batch[0], batch[1]
        else:
            input_batch = target_batch = batch[0]
    elif isinstance(batch, dict):
        input_batch  = batch.get("input", next(iter(batch.values())))
        target_batch = batch.get("target", input_batch)
    else:
        input_batch = target_batch = batch

    with torch.no_grad():
        input_batch  = input_batch.to(device, non_blocking=True)
        target_batch = target_batch.to(device, non_blocking=True)
        recon_batch, _, _ = model(input_batch)

    num_vis = min(4, input_batch.shape[0])
    # squeeze=False keeps `axes` two-dimensional even for a single column, so
    # axes[row, col] indexing works for every batch size.
    fig, axes = plt.subplots(3, num_vis, figsize=(4 * num_vis, 12), squeeze=False)

    def to_img(t):
        # The data pipeline uses ToTensor() and the decoder ends in a sigmoid,
        # so values are already in [0, 1]; no de-normalisation is applied.
        return t.detach().float().cpu().clamp(0, 1).permute(1, 2, 0).numpy()

    rows = (
        (input_batch, 'Input'),
        (recon_batch, 'Reconstruction'),
        (target_batch, 'Ground Truth'),
    )
    for i in range(num_vis):
        for row, (images, label) in enumerate(rows):
            axes[row, i].imshow(to_img(images[i]))
            axes[row, i].set_title(f'{label} {i+1}')
            axes[row, i].axis('off')

    plt.tight_layout()
    if wandb.run is not None:
        wandb.log({f"reconstructions/epoch_{epoch+1}": wandb.Image(fig), "epoch": epoch + 1})
    else:
        # Without an active W&B run the figure would otherwise never be seen.
        plt.show()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [None]:
def train_model(model, train_loader, val_loader, device, num_epochs=100, 
                learning_rate=1e-3, beta=1.0, save_dir='checkpoints'):
    """Train ``model`` and keep the checkpoint with the lowest validation loss.

    Uses Adam with a ``ReduceLROnPlateau`` scheduler driven by the validation
    loss. The best model is written to ``<save_dir>/best_vae_model.pth``; an
    additional checkpoint is written every 20 epochs and reconstructions are
    visualised every 10 epochs. All W&B calls are skipped when no run is
    active.

    Args:
        model: Module returning ``(reconstruction, mu, log_var)``.
        train_loader: Loader providing the training batches.
        val_loader: Loader providing the validation batches.
        device: Device used for training.
        num_epochs: Number of epochs to run.
        learning_rate: Initial Adam learning rate.
        beta: Weight of the KL term.
        save_dir: Directory checkpoints are written to (created if missing).

    Returns:
        Tuple ``(train_losses, val_losses)``; each is a list with one
        ``(loss, reconstruction_loss, kl_loss)`` tuple per epoch.
    """
    os.makedirs(save_dir, exist_ok=True)

    # wandb.run is None until wandb.init() has been called; guarding on it
    # lets the function run when training without W&B.
    log_to_wandb = wandb.run is not None

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    def save_checkpoint(path, epoch, val_loss):
        # Store everything needed to resume from or evaluate this epoch.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, path)

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    print("Starting Training")
    print(f"Device: {device}")
    print(f"Total epochs: {num_epochs}")
    print(f"Train samples: {len(train_loader.dataset)}")
    print(f"Val samples: {len(val_loader.dataset)}")
    print(f"Batch size: {train_loader.batch_size}")
    if log_to_wandb:
        print(f"W&B Project: {wandb.run.project}")
        print(f"W&B Run: {wandb.run.name}")

    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()

        train_loss, train_recon, train_kl = train_epoch(
            model, train_loader, optimizer, device, beta, epoch
        )

        val_loss, val_recon, val_kl = validate(model, val_loader, device, beta, epoch)
        scheduler.step(val_loss)
        # Read the rate after the scheduler step so a reduction is reported
        # in the same epoch it happens.
        new_lr = optimizer.param_groups[0]['lr']

        train_losses.append((train_loss, train_recon, train_kl))
        val_losses.append((val_loss, val_recon, val_kl))
        epoch_time = time.time() - epoch_start

        if log_to_wandb:
            wandb.log({
                'epoch': epoch + 1,
                'train/loss': train_loss,
                'train/reconstruction_loss': train_recon,
                'train/kl_loss': train_kl,
                'val/loss': val_loss,
                'val/reconstruction_loss': val_recon,
                'val/kl_loss': val_kl,
                'learning_rate': new_lr,
                'epoch_time': epoch_time
            })

        print(f"\nEpoch [{epoch+1:3d}/{num_epochs}] | Time: {epoch_time:.1f}s | LR: {new_lr:.2e}")
        print(f"  Train Loss: {train_loss:.6f} (Recon: {train_recon:.6f}, KL: {train_kl:.6f})")
        print(f"  Val   Loss: {val_loss:.6f} (Recon: {val_recon:.6f}, KL: {val_kl:.6f})")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_path = os.path.join(save_dir, 'best_vae_model.pth')
            save_checkpoint(checkpoint_path, epoch, val_loss)

            if log_to_wandb:
                wandb.save(checkpoint_path)

            print(f"  ✓ Saved best model (val_loss: {val_loss:.6f})")

            if log_to_wandb:
                artifact = wandb.Artifact('vae-model', type='model')
                artifact.add_file(checkpoint_path)
                wandb.log_artifact(artifact)

        if (epoch + 1) % 10 == 0:
            visualize_results(model, val_loader, device, epoch)

        if (epoch + 1) % 20 == 0:
            checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth')
            save_checkpoint(checkpoint_path, epoch, val_loss)

    total_time = time.time() - start_time

    if log_to_wandb:
        wandb.run.summary['best_val_loss'] = best_val_loss
        wandb.run.summary['total_training_time'] = total_time
        wandb.run.summary['total_epochs'] = num_epochs

    print(f"Training completed in {total_time/3600:.2f} hours")
    print(f"Best validation loss: {best_val_loss:.6f}")
    plot_losses(train_losses, val_losses)

    return train_losses, val_losses


In [None]:
def plot_losses(train_losses, val_losses):
    """Plot the per-epoch loss curves and log the figure to W&B if a run is active.

    The left panel compares total train and validation loss; the right panel
    shows the training reconstruction and KL terms.

    Args:
        train_losses: One ``(loss, reconstruction_loss, kl_loss)`` tuple per epoch.
        val_losses: Same layout as ``train_losses``, for the validation set.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    epochs = range(1, len(train_losses) + 1)

    axes[0].plot(epochs, [l[0] for l in train_losses], label='Train', linewidth=2)
    axes[0].plot(epochs, [l[0] for l in val_losses], label='Val', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Total Loss')
    axes[0].set_title('Total Loss (Reconstruction + KL)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(epochs, [l[1] for l in train_losses], label='Recon (Train)', linewidth=2)
    axes[1].plot(epochs, [l[2] for l in train_losses], label='KL (Train)', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Loss')
    axes[1].set_title('Reconstruction vs KL Divergence')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()

    # wandb.log requires an active run; skip it when none has been started.
    if wandb.run is not None:
        wandb.log({"loss_curves": wandb.Image(fig)})

    plt.show()
    # Release the figure once it has been displayed and logged.
    plt.close(fig)


In [None]:
class BallDetector:
    """Locate a ball in an image with OpenCV's Hough circle transform.

    Args:
        ball_radius: Nominal ball radius; stored on the instance.
        min_radius: Smallest circle radius the detector searches for.
        max_radius: Largest circle radius the detector searches for.
    """

    def __init__(self, ball_radius=15, min_radius=10, max_radius=25):
        self.ball_radius = ball_radius
        self.min_radius = min_radius
        self.max_radius = max_radius

    def detect_ball(self, image):
        """Return the first detected circle, or ``None`` when nothing is found.

        Args:
            image: Array that is either 3-dimensional (converted from RGB to
                grayscale) or already single-channel.

        Returns:
            ``{'position': (center_x, center_y), 'radius': radius}`` for the
            first circle reported by ``cv2.HoughCircles``, or ``None``.
        """
        # Three-dimensional input is treated as RGB; anything else is used as is.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image

        # Smooth before the transform.
        blurred = cv2.GaussianBlur(gray, (9, 9), 2)

        circles = cv2.HoughCircles(
            blurred,
            cv2.HOUGH_GRADIENT,
            dp=1,
            minDist=50,
            param1=50,
            param2=30,
            minRadius=self.min_radius,
            maxRadius=self.max_radius,
        )

        found_any = circles is not None and len(circles[0]) != 0
        if not found_any:
            return None

        center_x, center_y, radius = circles[0][0]
        return {'position': (center_x, center_y), 'radius': radius}

In [None]:
def evaluate_model(model, test_loader, device):
    """Measure ball-position error between reconstructions and targets.

    Every reconstruction and its target are converted to 8-bit RGB images and
    passed through ``BallDetector``. A prediction counts as failed when the
    ball is not found in either image; otherwise the absolute x / y distance
    between the two detected centres is recorded.

    Args:
        model: Module returning ``(reconstruction, mu, log_var)``.
        test_loader: Loader providing the test batches.
        device: Device the batches are moved to.

    Returns:
        Dict with mean / std of the x and y position errors (``0.0`` when no
        pair was detected), the failure rate in percent and the raw counts.
    """
    model.eval()
    detector = BallDetector()

    all_errors_x = []
    all_errors_y = []
    failed_predictions = 0
    total_predictions = 0

    sample_predictions = []
    print("Evaluating Model")

    def to_unit_range(t):
        # The data pipeline uses ToTensor() and the decoder ends in a sigmoid,
        # so values are already in [0, 1]; only clamp, never rescale.
        return t.clamp(0, 1)

    def to_uint8_image(chw):
        # (C, H, W) float array in [0, 1] -> (H, W, C) 8-bit image.
        return (chw.transpose(1, 2, 0) * 255.0).round().astype(np.uint8)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loader, desc='Evaluating')):
            if isinstance(batch, (list, tuple)):
                if len(batch) >= 2:
                    input_img, target_img = batch[0], batch[1]
                else:
                    input_img = target_img = batch[0]
            elif isinstance(batch, dict):
                input_img  = batch.get('input', next(iter(batch.values())))
                target_img = batch.get('target', input_img)
            else:
                input_img = target_img = batch

            input_img  = input_img.to(device, non_blocking=True)
            target_img = target_img.to(device, non_blocking=True)

            reconstruction, _, _ = model(input_img)
            pred_imgs   = to_unit_range(reconstruction).cpu().numpy()
            target_imgs = to_unit_range(target_img).cpu().numpy()

            for i in range(len(pred_imgs)):
                total_predictions += 1
                pred   = to_uint8_image(pred_imgs[i])
                target = to_uint8_image(target_imgs[i])

                # Keep the first ten prediction/target pairs as W&B samples.
                if len(sample_predictions) < 10:
                    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
                    axes[0].imshow(pred)
                    axes[0].set_title('Prediction')
                    axes[0].axis('off')
                    axes[1].imshow(target)
                    axes[1].set_title('Ground Truth')
                    axes[1].axis('off')
                    plt.tight_layout()
                    sample_predictions.append(wandb.Image(fig))
                    plt.close(fig)

                pred_ball   = detector.detect_ball(pred)
                target_ball = detector.detect_ball(target)

                if pred_ball is None or target_ball is None:
                    failed_predictions += 1
                    continue

                pos_error_x = abs(pred_ball['position'][0] - target_ball['position'][0])
                pos_error_y = abs(pred_ball['position'][1] - target_ball['position'][1])

                all_errors_x.append(pos_error_x)
                all_errors_y.append(pos_error_y)

    if total_predictions == 0:
        failure_rate = 0.0
    else:
        failure_rate = (failed_predictions / total_predictions) * 100.0

    results = {
        'test/position_x_mean': float(np.mean(all_errors_x)) if all_errors_x else 0.0,
        'test/position_x_std':  float(np.std(all_errors_x))  if all_errors_x else 0.0,
        'test/position_y_mean': float(np.mean(all_errors_y)) if all_errors_y else 0.0,
        'test/position_y_std':  float(np.std(all_errors_y))  if all_errors_y else 0.0,
        'test/failure_rate':    failure_rate,
        'test/total_predictions': total_predictions,
        'test/failed_predictions': failed_predictions,
    }

    # Logging is best effort: a W&B problem must not discard the metrics
    # that were just computed.
    try:
        wandb.log(results)
        if sample_predictions:
            wandb.log({"test/sample_predictions": sample_predictions})
        comparison_table = wandb.Table(
            columns=["Metric", "Paper (VAE)", "Our Results"],
            data=[
                ["Position X Error (pixels)", "4.69 ± 6.1", f"{results['test/position_x_mean']:.2f} ± {results['test/position_x_std']:.2f}"],
                ["Position Y Error (pixels)", "6.25 ± 6.9", f"{results['test/position_y_mean']:.2f} ± {results['test/position_y_std']:.2f}"],
                ["Failure Rate (%)", "95%", f"{failure_rate:.2f}%"]
            ]
        )
        wandb.log({"comparison_with_paper": comparison_table})
    except Exception as e:
        print(f"(wandb logging skipped: {e})")

    print("EVALUATION RESULTS")
    print(f"\nTotal predictions: {total_predictions}")
    print(f"Failed predictions: {failed_predictions} ({failure_rate:.2f}%)")
    if all_errors_x:
        print(f"\nPosition X Error: {results['test/position_x_mean']:.2f} ± {results['test/position_x_std']:.2f} pixels")
        print(f"Position Y Error: {results['test/position_y_mean']:.2f} ± {results['test/position_y_std']:.2f} pixels")

    print("COMPARISON WITH PAPER RESULTS (Table 3)")
    print("Paper VAE Results:")
    print("  Position X: 4.69 ± 6.1 pixels")
    print("  Position Y: 6.25 ± 6.9 pixels")
    print("  Failure Rate: 95%")

    return results

In [None]:
def main(use_wandb=True, wandb_project="physicsgen-vae", wandb_entity=None,
         data_root=None, config=None):
    """Train the VAE, evaluate it on the test split and return the results.

    Args:
        use_wandb: When true, start a W&B run for the experiment and finish it
            at the end (also when training fails).
        wandb_project: W&B project name.
        wandb_entity: W&B entity; ``None`` uses the account default.
        data_root: Dataset directory containing ``train`` / ``val`` / ``test``.
            ``None`` selects the built-in Kaggle path.
        config: Optional dict whose entries override the default
            hyper-parameters defined below.

    Returns:
        Tuple ``(model, results)`` with the trained model and the metrics dict
        produced by ``evaluate_model``.
    """
    DATA_ROOT = '/kaggle/input/ball-data/ball_bounce/ball_bounce' if data_root is None else data_root
    CONFIG = {
        'latent_dim': 128,
        'image_size': 256,
        'batch_size': 32,
        'num_epochs': 10,
        'learning_rate': 1e-3,
        'beta': 1.0,
        'num_workers': 2,
        'motion_type': 'bouncing',
    }
    if config:
        CONFIG.update(config)

    if use_wandb:
        wandb.init(
            project=wandb_project,
            entity=wandb_entity,
            config=CONFIG,
            name=f"vae-{CONFIG['motion_type']}-latent{CONFIG['latent_dim']}",
            tags=['vae', 'physicsgen', CONFIG['motion_type']],
            notes="Replicating VAE baseline from PhysicsGen CVPR 2025"
        )
        run_config = wandb.config
    else:
        run_config = CONFIG

    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        if device.type == 'cuda':
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
            print(f"GPU: {gpu_name}")
            print(f"Memory: {gpu_memory_gb:.2f} GB")
            if use_wandb:
                wandb.config.update({
                    'gpu': gpu_name,
                    'gpu_memory_gb': gpu_memory_gb
                })

        train_loader, val_loader, test_loader = create_data_loaders(
            DATA_ROOT,
            batch_size=run_config['batch_size'],
            image_size=run_config['image_size'],
            num_workers=run_config['num_workers']
        )

        print("Initializing Model")
        model = VAE(
            latent_dim=run_config['latent_dim'],
            input_channels=3,
            image_size=run_config['image_size']
        ).to(device)

        num_params = sum(p.numel() for p in model.parameters())
        print(f"Model parameters: {num_params:,}")

        if use_wandb:
            wandb.config.update({'model_parameters': num_params})
            wandb.watch(model, log='all', log_freq=100)

        train_model(
            model, train_loader, val_loader, device,
            num_epochs=run_config['num_epochs'],
            learning_rate=run_config['learning_rate'],
            beta=run_config['beta']
        )

        results = evaluate_model(model, test_loader, device)
    except BaseException:
        # Close the run as failed (also on interruption) instead of leaving
        # it open, then let the original error propagate.
        if use_wandb:
            wandb.finish(exit_code=1)
        raise

    if use_wandb:
        wandb.finish()

    print("\n✓ Training and evaluation complete!")
    if use_wandb:
        print("✓ Check W&B dashboard for detailed logs and visualizations")

    return model, results

In [None]:
if __name__ == '__main__':
    # Notebook entry point: announce the run settings, then train and evaluate.

    # Weights & Biases settings forwarded to main().
    USE_WANDB = True
    WANDB_ENTITY = None
    WANDB_PROJECT = "physicsgen-vae-replication"  # W&B project name

    # NOTE(review): DATA_ROOT is only printed and CONFIG_TEST is never read in
    # this cell; neither is passed to main() by the call below.
    DATA_ROOT = '/kaggle/input/ball-data/ball_bounce/ball_bounce'
    CONFIG_TEST = dict(
        latent_dim=128,
        image_size=256,
        batch_size=32,
        num_epochs=10,
        learning_rate=1e-3,
        beta=1.0,
        num_workers=2,
        motion_type='bouncing',
    )

    wandb_state = 'Enabled' if USE_WANDB else 'Disabled'
    print("PhysicsGen VAE Baseline - Kaggle Training Script")
    print(f"W&B Logging: {wandb_state}")
    print(f"Data Path: {DATA_ROOT}")

    model, results = main(
        use_wandb=USE_WANDB, wandb_project=WANDB_PROJECT, wandb_entity=WANDB_ENTITY
    )