# SZ

In [2]:
"""
Pix2Pix Implementation for PhysicsGen Motion Prediction
Based on paper appendix - Table 5 hyperparameters
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
from PIL import Image
import os
import cv2
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time
from pathlib import Path
import itertools

try:
    import wandb
except ImportError:
    print("Installing wandb...")
    !pip install wandb -q
    import wandb

In [3]:
os.environ['WANDB_API_KEY'] = 'f3579be751c0a30268d1d947d8e79f9ea3486905'
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdhritik[0m ([33mdhritik-carnegie-mellon-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
class CombinedImageDataset(Dataset):
    def __init__(self, data_dir, transform=None, split_direction='horizontal'):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.split_direction = split_direction
        self.image_files = sorted([
            f for f in os.listdir(self.data_dir) 
            if f.endswith(('.png', '.jpg', '.jpeg'))
        ])
        
        print(f"{len(self.image_files)} combined images from {data_dir}")
        
        if self.image_files:
            first_img = Image.open(self.data_dir / self.image_files[0])
            w, h = first_img.size
            print(f"  Image size: {w}x{h}")
            
            if split_direction == 'horizontal':
                if w > h * 1.5:
                    print(f"  Detected horizontal split (width > height)")
                    print(f"  Will split into: {w//2}x{h} (left) and {w//2}x{h} (right)")
                else:
                    print(f"  Image might not be horizontally combined")
            else:
                if h > w * 1.5:
                    print(f"  ✓ Detected vertical split (height > width)")
                    print(f"  Will split into: {w}x{h//2} (top) and {w}x{h//2} (bottom)")
                else:
                    print(f"  Image might not be vertically combined")
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = self.data_dir / self.image_files[idx]
        combined_img = Image.open(img_path).convert('RGB')
        
        w, h = combined_img.size
        
        if self.split_direction == 'horizontal':
            mid = w // 2
            input_img = combined_img.crop((0, 0, mid, h))
            target_img = combined_img.crop((mid, 0, w, h))
        else:
            mid = h // 2
            input_img = combined_img.crop((0, 0, w, mid))
            target_img = combined_img.crop((0, mid, w, h))
        
        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)
        
        return input_img, target_img

In [5]:
def create_data_loaders(data_root, batch_size=18, image_size=256, num_workers=2, 
                       split_direction='horizontal'):
    
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    
    print(f"Split direction: {split_direction}")
    
    
    train_path = Path(data_root) / 'train_double'
    print("\nTrain set:")
    train_dataset = CombinedImageDataset(
        str(train_path), 
        transform=transform,
        split_direction=split_direction
    )
    
    val_path = Path(data_root) / 'val_double'
    test_path = Path(data_root) / 'test_double'
    
    if val_path.exists():
        print("\nValidation set:")
        val_dataset = CombinedImageDataset(
            str(val_path), 
            transform=transform,
            split_direction=split_direction
        )
    else:
        print("\nValidation set: Creating 10% split from training")
        train_size = int(0.9 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, [train_size, val_size]
        )
    
    if test_path.exists():
        print("\nTest set:")
        test_dataset = CombinedImageDataset(
            str(test_path), 
            transform=transform,
            split_direction=split_direction
        )
    else:
        print("\nTest set: Using val set")
        test_dataset = val_dataset
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, 
                             shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, 
                           shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, 
                            shuffle=False, num_workers=num_workers)
    
    return train_loader, val_loader, test_loader

In [6]:
class UNetDown(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels, affine=True))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x

In [7]:
class AddCoords(nn.Module):
    """
    Adds two extra channels (X and Y coordinates) to the input tensor.
    This gives the network 'spatial awareness', which is crucial for
    physics tasks like gravity or boundary detection.
    """
    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape (batch, channel, x_dim, y_dim)
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()

        
        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

      
        xx_channel = xx_channel.float() / (x_dim - 1)
        yy_channel = yy_channel.float() / (y_dim - 1)
        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

       
        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

       
        xx_channel = xx_channel.type_as(input_tensor)
        yy_channel = yy_channel.type_as(input_tensor)

        ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        return ret

In [8]:
class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(Generator, self).__init__()


        self.add_coords = AddCoords(with_r=False)
        self.down1 = UNetDown(in_channels + 2, 64, normalize=False) 
        self.down2 = UNetDown(64, 128)                          
        self.down3 = UNetDown(128, 256)                       
        self.down4 = UNetDown(256, 512)          
        self.down5 = UNetDown(512, 512)
        self.down6 = UNetDown(512, 512)
        self.down7 = UNetDown(512, 512)
        self.down8 = UNetDown(512, 512, normalize=False) 

        
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5) 
        self.up4 = UNetUp(1024, 512)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128) 
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):

        x = self.add_coords(x)
        
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)

In [9]:
class Discriminator(nn.Module):
    
    def __init__(self, in_channels=6):
        super(Discriminator, self).__init__()
        self.add_coords = AddCoords(with_r=False)

        def discriminator_block(in_filters, out_filters, normalization=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters, affine=True))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels+2, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        img_input = torch.cat((img_A, img_B), 1)
        img_input = self.add_coords(img_input)
        return self.model(img_input)

In [10]:
def compute_gradient_penalty(discriminator, real_A, real_B, fake_B, device, lambda_gp=10):
    
    batch_size = real_B.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    
    interpolates = (alpha * real_B + (1 - alpha) * fake_B).requires_grad_(True)
    d_interpolates = discriminator(real_A, interpolates)
    fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)
    
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
    
    return gradient_penalty

In [11]:
def train_epoch(generator, discriminator, train_loader, optimizer_G, optimizer_D, 
                criterion_GAN, criterion_L1, lambda_L1, lambda_gp, device, epoch):
    generator.train()
    discriminator.train()
    
    epoch_g_loss = 0
    epoch_d_loss = 0
    epoch_gan_loss = 0
    epoch_l1_loss = 0
    epoch_gp_loss = 0 
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1} - Training')
    for batch_idx, (real_A, real_B) in enumerate(pbar):
        real_A = real_A.to(device)
        real_B = real_B.to(device)
        
        batch_size = real_A.size(0)
        valid = torch.ones((batch_size, 1, 16, 16), device=device, requires_grad=False)
        fake = torch.zeros((batch_size, 1, 16, 16), device=device, requires_grad=False)
        
        optimizer_G.zero_grad()
        fake_B = generator(real_A)
        pred_fake = discriminator(real_A, fake_B)
        loss_GAN = criterion_GAN(pred_fake, valid)
        loss_L1 = criterion_L1(fake_B, real_B)
        loss_G = loss_GAN + lambda_L1 * loss_L1
        loss_G.backward()
        optimizer_G.step()
        
        optimizer_D.zero_grad()
        
        pred_real = discriminator(real_A, real_B)
        loss_real = criterion_GAN(pred_real, valid)
        
        pred_fake = discriminator(real_A, fake_B.detach())
        loss_fake = criterion_GAN(pred_fake, fake)
        
        gp = compute_gradient_penalty(discriminator, real_A, real_B, 
                                     fake_B.detach(), device, lambda_gp)
        
        loss_D = 0.5 * (loss_real + loss_fake) + gp
        loss_D.backward()
        optimizer_D.step()
        
        epoch_g_loss += loss_G.item()
        epoch_d_loss += loss_D.item()
        epoch_gan_loss += loss_GAN.item()
        epoch_l1_loss += loss_L1.item()
        epoch_gp_loss += gp.item()

        if batch_idx % 10 == 0:
            wandb.log({
                'batch/g_loss': loss_G.item(),
                'batch/d_loss': loss_D.item(),
                'batch/gan_loss': loss_GAN.item(),
                'batch/l1_loss': loss_L1.item(),
                'batch/gp_loss': gp.item(),
                'batch/step': epoch * len(train_loader) + batch_idx
            })
        
        pbar.set_postfix({
            'G': f'{loss_G.item():.4f}',
            'D': f'{loss_D.item():.4f}',
            'L1': f'{loss_L1.item():.4f}',
            'GP': f'{gp.item():.4f}'
        })
    
    num_batches = len(train_loader)
    return (epoch_g_loss/num_batches, epoch_d_loss/num_batches,
            epoch_gan_loss/num_batches, epoch_l1_loss/num_batches,
            epoch_gp_loss/num_batches)

In [12]:
def validate(generator, discriminator, val_loader, criterion_GAN, 
             criterion_L1, lambda_L1, device, epoch):
    generator.eval()
    discriminator.eval()
    
    val_g_loss = 0
    val_d_loss = 0
    
    with torch.no_grad():
        for real_A, real_B in tqdm(val_loader, desc=f'Epoch {epoch+1} - Validation'):
            real_A = real_A.to(device)
            real_B = real_B.to(device)
            
            batch_size = real_A.size(0)
            valid = torch.ones((batch_size, 1, 16, 16), device=device)
            fake = torch.zeros((batch_size, 1, 16, 16), device=device)
            
            fake_B = generator(real_A)
            pred_fake = discriminator(real_A, fake_B)
            loss_GAN = criterion_GAN(pred_fake, valid)
            loss_L1 = criterion_L1(fake_B, real_B)
            loss_G = loss_GAN + lambda_L1 * loss_L1
            
            pred_real = discriminator(real_A, real_B)
            loss_real = criterion_GAN(pred_real, valid)
            pred_fake = discriminator(real_A, fake_B)
            loss_fake = criterion_GAN(pred_fake, fake)
            loss_D = 0.5 * (loss_real + loss_fake)
            
            val_g_loss += loss_G.item()
            val_d_loss += loss_D.item()
    
    num_batches = len(val_loader)
    return val_g_loss/num_batches, val_d_loss/num_batches

In [13]:
def train_pix2pix(generator, discriminator, train_loader, val_loader, 
                  config, device, save_dir='checkpoints'):
    os.makedirs(save_dir, exist_ok=True)
    
    optimizer_G = optim.Adam(generator.parameters(), 
                            lr=config['lr_generator'], 
                            betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), 
                            lr=config['lr_discriminator'], 
                            betas=(0.5, 0.999))
    
    criterion_GAN = nn.MSELoss()
    criterion_L1 = nn.L1Loss()
    
    best_val_loss = float('inf')
    print("Starting Pix2Pix Training with Gradient Penalty")
    print("="*70)
    print(f"Lambda L1: {config['lambda_l1']}")
    print(f"Lambda GP: {config['lambda_gp']}")  # Show GP weight
    print("="*70)
    
    start_time = time.time()
    
    for epoch in range(config['num_epochs']):
        epoch_start = time.time()
        
        # Train with gradient penalty
        train_g, train_d, train_gan, train_l1, train_gp = train_epoch(
            generator, discriminator, train_loader,
            optimizer_G, optimizer_D,
            criterion_GAN, criterion_L1, 
            config['lambda_l1'], config['lambda_gp'],  # Pass lambda_gp
            device, epoch
        )
        
        # Validate
        val_g, val_d = validate(
            generator, discriminator, val_loader,
            criterion_GAN, criterion_L1, config['lambda_l1'],
            device, epoch
        )
        
        epoch_time = time.time() - epoch_start
        
        # Log to wandb (including GP)
        wandb.log({
            'epoch': epoch + 1,
            'train/g_loss': train_g,
            'train/d_loss': train_d,
            'train/gan_loss': train_gan,
            'train/l1_loss': train_l1,
            'train/gp_loss': train_gp,  # Log gradient penalty
            'val/g_loss': val_g,
            'val/d_loss': val_d,
            'epoch_time': epoch_time
        })
        
        print(f"\nEpoch [{epoch+1}/{config['num_epochs']}] | Time: {epoch_time:.1f}s")
        print(f"  Train - G: {train_g:.4f}, D: {train_d:.4f}, GP: {train_gp:.4f}")
        print(f"  Val   - G: {val_g:.4f}, D: {val_d:.4f}")
        
        if val_g < best_val_loss:
            best_val_loss = val_g
            checkpoint = {
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'val_loss': val_g,
            }
            torch.save(checkpoint, os.path.join(save_dir, 'best_pix2pix.pth'))
            wandb.save(os.path.join(save_dir, 'best_pix2pix.pth'))
            print(f"  ✓ Saved best model")
        
        # Visualize every 10 epochs
        if (epoch + 1) % 10 == 0:
            visualize_results(generator, val_loader, device, epoch)
    
    total_time = time.time() - start_time
    wandb.run.summary['best_val_loss'] = best_val_loss
    wandb.run.summary['total_training_time'] = total_time
    
    print("\n" + "="*70)
    print(f"Training completed in {total_time/3600:.2f} hours")
    print("="*70)

In [14]:
def visualize_results(generator, val_loader, device, epoch):
    generator.eval()
    
    with torch.no_grad():
        real_A, real_B = next(iter(val_loader))
        real_A = real_A.to(device)
        fake_B = generator(real_A)
        
        def denorm(x):
            return (x + 1) / 2
        
        num_vis = min(4, len(real_A))
        fig, axes = plt.subplots(3, num_vis, figsize=(4*num_vis, 12))
        
        for i in range(num_vis):
            input_img = denorm(real_A[i]).cpu().permute(1, 2, 0).numpy()
            axes[0, i].imshow(input_img)
            axes[0, i].set_title(f'Input {i+1}')
            axes[0, i].axis('off')
            
            gen_img = denorm(fake_B[i]).cpu().permute(1, 2, 0).numpy()
            axes[1, i].imshow(gen_img)
            axes[1, i].set_title(f'Generated {i+1}')
            axes[1, i].axis('off')
            
            target_img = denorm(real_B[i]).cpu().permute(1, 2, 0).numpy()
            axes[2, i].imshow(target_img)
            axes[2, i].set_title(f'Ground Truth {i+1}')
            axes[2, i].axis('off')
        
        plt.tight_layout()
        wandb.log({f"generations/epoch_{epoch+1}": wandb.Image(fig), "epoch": epoch + 1})
        plt.close()

In [15]:
class BallDetector:
    
    def __init__(self, ball_radius=15, min_radius=10, max_radius=25):
        self.ball_radius = ball_radius
        self.min_radius = min_radius
        self.max_radius = max_radius
    
    def detect_ball(self, image):
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        
        blurred = cv2.GaussianBlur(gray, (9, 9), 2)
        circles = cv2.HoughCircles(
            blurred, cv2.HOUGH_GRADIENT, dp=1, minDist=50,
            param1=50, param2=30,
            minRadius=self.min_radius, maxRadius=self.max_radius
        )
        
        if circles is None or len(circles[0]) == 0:
            return None
        
        circle = circles[0][0]
        return {'position': (circle[0], circle[1]), 'radius': circle[2]}

In [16]:
def evaluate_pix2pix(generator, test_loader, device):
    generator.eval()
    detector = BallDetector()
    
    all_errors_x = []
    all_errors_y = []
    failed_predictions = 0
    total_predictions = 0
    sample_predictions = []
    
    print("Evaluating Pix2Pix Model")
    
    with torch.no_grad():
        for batch_idx, (real_A, real_B) in enumerate(tqdm(test_loader, desc='Evaluating')):
            real_A = real_A.to(device)
            fake_B = generator(real_A)
            
            fake_imgs = ((fake_B + 1) / 2 * 255).cpu().numpy().astype(np.uint8)
            real_imgs = ((real_B + 1) / 2 * 255).numpy().astype(np.uint8)
            
            for i in range(len(fake_imgs)):
                total_predictions += 1
                
                pred = fake_imgs[i].transpose(1, 2, 0)
                target = real_imgs[i].transpose(1, 2, 0)
                
                # Log first 10 samples
                if len(sample_predictions) < 10:
                    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
                    axes[0].imshow(pred)
                    axes[0].set_title('Prediction')
                    axes[0].axis('off')
                    axes[1].imshow(target)
                    axes[1].set_title('Ground Truth')
                    axes[1].axis('off')
                    plt.tight_layout()
                    sample_predictions.append(wandb.Image(fig))
                    plt.close()
                
                pred_ball = detector.detect_ball(pred)
                target_ball = detector.detect_ball(target)
                
                if pred_ball is None or target_ball is None:
                    failed_predictions += 1
                    continue
                
                pos_error_x = abs(pred_ball['position'][0] - target_ball['position'][0])
                pos_error_y = abs(pred_ball['position'][1] - target_ball['position'][1])
                
                all_errors_x.append(pos_error_x)
                all_errors_y.append(pos_error_y)
    
    failure_rate = (failed_predictions / total_predictions) * 100
    
    results = {
        'test/position_x_mean': np.mean(all_errors_x) if len(all_errors_x) > 0 else 0,
        'test/position_x_std': np.std(all_errors_x) if len(all_errors_x) > 0 else 0,
        'test/position_y_mean': np.mean(all_errors_y) if len(all_errors_y) > 0 else 0,
        'test/position_y_std': np.std(all_errors_y) if len(all_errors_y) > 0 else 0,
        'test/failure_rate': failure_rate,
        'test/total_predictions': total_predictions,
        'test/failed_predictions': failed_predictions
    }
    
    wandb.log(results)
    wandb.log({"test/sample_predictions": sample_predictions})
    
    comparison_table = wandb.Table(
        columns=["Metric", "Paper (Pix2Pix)", "Our Results"],
        data=[
            ["Position X Error (pixels)", "6.28 ± 7.98", f"{results['test/position_x_mean']:.2f} ± {results['test/position_x_std']:.2f}"],
            ["Position Y Error (pixels)", "11.7 ± 12.8", f"{results['test/position_y_mean']:.2f} ± {results['test/position_y_std']:.2f}"],
            ["Failure Rate (%)", "7%", f"{failure_rate:.2f}%"]
        ]
    )
    wandb.log({"comparison_with_paper": comparison_table})
    
    print("\n" + "="*70)
    print("EVALUATION RESULTS")
    print(f"Total: {total_predictions}, Failed: {failed_predictions} ({failure_rate:.2f}%)")
    
    if len(all_errors_x) > 0:
        print(f"\nPosition X Error: {results['test/position_x_mean']:.2f} ± {results['test/position_x_std']:.2f} pixels")
        print(f"Position Y Error: {results['test/position_y_mean']:.2f} ± {results['test/position_y_std']:.2f} pixels")
    
    print("\n" + "="*70)
    print("COMPARISON WITH PAPER RESULTS")
    print("Paper Pix2Pix Results:")
    print("  Position X: 6.28 ± 7.98 pixels")
    print("  Position Y: 11.7 ± 12.8 pixels")
    print("  Rotation: 17.2 ± 20.8 degrees")
    print("  Roundness: 0.56 ± 0.14 pixels")
    print("  Failure Rate: 7%")
    print("  Number of balls: 93% single ball, 7% no/multiple balls")
    
    return results

In [17]:
def main(use_wandb=True, wandb_project="physicsgen-pix2pix", wandb_entity=None):
    DATA_ROOT = '/kaggle/input/bounce-ball/bounce_ball'
    
    CONFIG = {
        'batch_size': 18,
        'lr_discriminator': 1e-4, #try changing to 5e-5
        'lr_generator': 2e-4,
        'num_epochs': 50,
        'lambda_l1': 100,
        'lambda_gp': 10,
        'image_size': 256,
        'num_workers': 2,
        'split_direction': 'horizontal',
    }
    
    if use_wandb:
        wandb.init(
            project=wandb_project,
            entity=wandb_entity,
            config=CONFIG,
            name=f"pix2pix-instancenorm-gp",
            tags=['pix2pix', 'instancenorm', 'gradient-penalty', 'physicsgen']
        )
        config = wandb.config
    else:
        config = CONFIG
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    train_loader, val_loader, test_loader = create_data_loaders(
        DATA_ROOT,
        batch_size=config['batch_size'],
        image_size=config['image_size'],
        num_workers=config['num_workers'],
        split_direction=config['split_direction']
    )
    
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    print(f"\nGenerator params: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

    if use_wandb:
        wandb.watch(generator, log='all', log_freq=100)
        wandb.watch(discriminator, log='all', log_freq=100)
    
    train_pix2pix(generator, discriminator, train_loader, val_loader, config, device)
    
    if use_wandb:
        wandb.finish()
    
    return generator, discriminator

In [None]:
if __name__ == '__main__':
    DATA_ROOT = '/kaggle/input/bounce-ball/bounce_ball'
    USE_WANDB = True
    WANDB_PROJECT = "physicsgen-pix2pix"
    
    generator, discriminator = main(
        use_wandb=USE_WANDB,
        wandb_project=WANDB_PROJECT
    )

Using device: cuda
Split direction: horizontal

Train set:
44835 combined images from /kaggle/input/bounce-ball/bounce_ball/train_double
  Image size: 1024x512
  Detected horizontal split (width > height)
  Will split into: 512x512 (left) and 512x512 (right)

Validation set:
58 combined images from /kaggle/input/bounce-ball/bounce_ball/val_double
  Image size: 1024x512
  Detected horizontal split (width > height)
  Will split into: 512x512 (left) and 512x512 (right)

Test set:
1600 combined images from /kaggle/input/bounce-ball/bounce_ball/test_double
  Image size: 1024x512
  Detected horizontal split (width > height)
  Will split into: 512x512 (left) and 512x512 (right)

Generator params: 54,416,003
Discriminator params: 2,771,648
Starting Pix2Pix Training with Gradient Penalty
Lambda L1: 100
Lambda GP: 10


Epoch 1 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 1 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [1/50] | Time: 1125.3s
  Train - G: 15.5230, D: 0.5367, GP: 0.3276
  Val   - G: 15.0174, D: 0.1539
  ✓ Saved best model


Epoch 2 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 2 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [2/50] | Time: 1122.5s
  Train - G: 15.1867, D: 0.1467, GP: 0.0153
  Val   - G: 15.1047, D: 0.1190


Epoch 3 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 3 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [3/50] | Time: 1122.5s
  Train - G: 15.1894, D: 0.1243, GP: 0.0156
  Val   - G: 15.1567, D: 0.1038


Epoch 4 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 4 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [4/50] | Time: 1122.6s
  Train - G: 15.1733, D: 0.1310, GP: 0.0215
  Val   - G: 15.1090, D: 0.1010


Epoch 5 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 5 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [5/50] | Time: 1122.2s
  Train - G: 15.1796, D: 0.1149, GP: 0.0148
  Val   - G: 15.1499, D: 0.0963


Epoch 6 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 6 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [6/50] | Time: 1124.0s
  Train - G: 15.1851, D: 0.1074, GP: 0.0122
  Val   - G: 15.1404, D: 0.0948


Epoch 7 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 7 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [7/50] | Time: 1123.1s
  Train - G: 15.1783, D: 0.1096, GP: 0.0132
  Val   - G: 15.2285, D: 0.0900


Epoch 8 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 8 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [8/50] | Time: 1122.3s
  Train - G: 15.1781, D: 0.1048, GP: 0.0112
  Val   - G: 15.2383, D: 0.0904


Epoch 9 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 9 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [9/50] | Time: 1123.7s
  Train - G: 15.1772, D: 0.0976, GP: 0.0080
  Val   - G: 15.2997, D: 0.0921


Epoch 10 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 10 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [10/50] | Time: 1122.5s
  Train - G: 15.1630, D: 0.1065, GP: 0.0119
  Val   - G: 15.2015, D: 0.0885


Epoch 11 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 11 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [11/50] | Time: 1119.7s
  Train - G: 15.1615, D: 0.0976, GP: 0.0077
  Val   - G: 15.3039, D: 0.0915


Epoch 12 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]

Epoch 12 - Validation:   0%|          | 0/4 [00:00<?, ?it/s]


Epoch [12/50] | Time: 1121.4s
  Train - G: 15.1566, D: 0.0961, GP: 0.0073
  Val   - G: 15.2009, D: 0.0843


Epoch 13 - Training:   0%|          | 0/2491 [00:00<?, ?it/s]