In [None]:
# Cell 1: Import required libraries
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import glob

# PyTorch3D imports
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
    TexturesVertex
)
from pytorch3d.io import load_obj, save_obj

# Pick the compute device: the first CUDA GPU when one is available, else the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Make the selected GPU the current CUDA device for later allocations.
    torch.cuda.set_device(device)
else:
    print("WARNING: CPU only, this will be slow!")

In [None]:
# Cell 2: Load and preprocess images
def load_images_from_folder(folder_path, target_size=(224, 224)):
    """
    Load every JPG/PNG image in ``folder_path`` as a float RGB tensor.

    Each image is converted to RGB and shrunk with ``Image.thumbnail`` so that
    it fits inside ``target_size``. ``thumbnail`` preserves the aspect ratio,
    so images with different aspect ratios come out with different shapes.

    Args:
        folder_path: Directory searched (non-recursively) for ``*.jpg`` and
            ``*.png`` files; the extension match is case-insensitive.
        target_size: Bounding box handed to ``Image.thumbnail``.

    Returns:
        List of float32 tensors shaped (H, W, 3) with values in [0, 1],
        ordered by file path. Files that fail to load are reported on stdout
        and skipped.
    """
    # glob returns files in filesystem order; sort so that the image -> index
    # mapping is identical on every run and every machine.
    image_paths = sorted(
        glob.glob(os.path.join(folder_path, "*.[jJ][pP][gG]"))
        + glob.glob(os.path.join(folder_path, "*.[pP][nN][gG]"))
    )

    images = []
    for img_path in image_paths:
        try:
            # The context manager closes the file handle as soon as the pixel
            # data has been decoded; convert() returns an independent image
            # (RGBA and grayscale inputs become RGB).
            with Image.open(img_path) as handle:
                rgb = handle.convert('RGB')

            # In-place downscale that keeps the aspect ratio
            rgb.thumbnail(target_size, Image.Resampling.LANCZOS)

            # uint8 [0, 255] -> float32 [0, 1]
            img_tensor = torch.from_numpy(np.array(rgb)).float() / 255.0
            images.append(img_tensor)

        except Exception as e:
            # Best effort: one unreadable file must not abort the whole load
            print(f"Error loading image {img_path}: {str(e)}")

    return images

# Load images from your media folder
media_path = "./media"  # Update this path to your media folder location
input_images = load_images_from_folder(media_path)

print(f"Loaded {len(input_images)} images")

# Preview up to five of the loaded images.
preview_images = input_images[:5]
if preview_images:
    # squeeze=False always yields a 2-D axes array, so a single image needs
    # no special-casing.
    fig, axes = plt.subplots(
        1, len(preview_images), figsize=(15, 3), squeeze=False
    )
    for ax, img_tensor in zip(axes[0], preview_images):
        ax.imshow(img_tensor)
        ax.axis('off')
    plt.show()
else:
    # plt.subplots rejects a zero-column grid, so skip plotting entirely.
    print(f"No images found in {media_path!r}; nothing to display")

In [None]:
# Cell 3: Dataset analysis and test-subset selection
def analyze_image_dataset(images):
    """Print summary statistics for the images and return per-channel stats.

    Args:
        images: Non-empty sequence of tensors shaped (H, W, C) that share the
            same channel count C. H and W may differ between images.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors of length C computed over every
        pixel of every image.

    Raises:
        ValueError: If ``images`` is empty.
    """
    from collections import Counter

    if len(images) == 0:
        raise ValueError("analyze_image_dataset() requires at least one image")

    # One pass over the images instead of re-scanning the list per unique size
    size_counts = Counter(str(img.shape) for img in images)

    print("Dataset statistics:")
    print(f"Number of images: {len(images)}")
    print(f"Unique image sizes: {len(size_counts)}")
    for size, count in size_counts.items():
        print(f"  {size}: {count} images")

    # Flatten every image to (num_pixels, C) and concatenate. Unlike
    # torch.stack this also works when the images have different sizes, and
    # it gives the same statistics when they all share one size.
    pixels = torch.cat(
        [img.reshape(-1, img.shape[-1]) for img in images], dim=0
    )
    mean = pixels.mean(dim=0)
    std = pixels.std(dim=0)

    return mean, std

# Analyze the loaded dataset
mean, std = analyze_image_dataset(input_images)

# Create a subset for initial testing (optional)
# Cap the subset at the number of available images: asking linspace for more
# points than there are images would select the same image several times.
n_test_images = min(20, len(input_images))  # Adjust based on your computational resources
# round() before long() so float error in linspace cannot truncate an index
# down onto its neighbour.
test_indices = torch.linspace(0, len(input_images) - 1, n_test_images).round().long()
test_images = [input_images[i] for i in test_indices.tolist()]

# Display mean values per channel
print("\nChannel-wise mean values:")
print(f"R: {mean[0]:.3f}, G: {mean[1]:.3f}, B: {mean[2]:.3f}")

In [None]:
# Cell 4 (Updated): Initialize base mesh and renderer configuration
def create_base_mesh(device):
    """
    Build the template mesh that later gets deformed: an icosphere with a
    uniform white per-vertex texture.

    Args:
        device: Torch device on which the sphere and its texture are created.

    Returns:
        The ``ico_sphere`` mesh with a ``TexturesVertex`` texture attached.
    """
    from pytorch3d.utils import ico_sphere

    # Subdivision level 4 gives a reasonably detailed sphere
    sphere = ico_sphere(4, device)

    # One white RGB triple per vertex, batched as (1, num_verts, 3)
    num_verts = sphere.verts_padded().shape[1]
    white = torch.ones((1, num_verts, 3), device=device)
    sphere.textures = TexturesVertex(verts_features=white)

    return sphere

# Initialize our renderer
def create_renderer(image_size=(126, 224), device=device,
                    blur_radius=0.0, faces_per_pixel=1):
    """
    Build a Phong-shaded mesh renderer with a default camera and a point light.

    Args:
        image_size: Output resolution forwarded to ``RasterizationSettings``.
            Intended to match the input image dimensions; the default
            (126, 224) is presumably (height, width) -- TODO confirm against
            the loaded images.
        device: Torch device for the camera, the light and the shader.
        blur_radius: Forwarded to ``RasterizationSettings``. The default 0.0
            keeps the original hard rasterisation; pass a positive value for
            soft edges.
        faces_per_pixel: Forwarded to ``RasterizationSettings``. The default 1
            keeps only one face per pixel.

    Returns:
        A ``MeshRenderer`` combining a ``MeshRasterizer`` and a
        ``SoftPhongShader``. A different camera can be supplied per call via
        ``renderer(mesh, cameras=...)``.
    """
    # Default camera; callers normally override it per render call
    cameras = FoVPerspectiveCameras(device=device)

    # Rasterization settings
    raster_settings = RasterizationSettings(
        image_size=image_size,  # Match input image dimensions
        blur_radius=blur_radius,
        faces_per_pixel=faces_per_pixel,
    )

    # A single point light at a fixed world position; it does not follow the
    # camera.
    lights = PointLights(
        device=device,
        location=[[0.0, 0.0, -3.0]],
        ambient_color=((0.5, 0.5, 0.5),),
        diffuse_color=((0.3, 0.3, 0.3),),
        specular_color=((0.2, 0.2, 0.2),),
    )

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=SoftPhongShader(
            device=device,
            cameras=cameras,
            lights=lights
        )
    )

    return renderer

# Instantiate the template sphere and the renderer shared by the later cells
base_mesh = create_base_mesh(device=device)
renderer = create_renderer(device=device, image_size=(126, 224))

# Visualize initial mesh from a few viewpoints
def visualize_mesh(mesh, renderer, num_views=3, dist=2.7, elev=0.0):
    """
    Render ``mesh`` from ``num_views`` distinct viewpoints and show them
    side by side.

    The cameras sit on a circle around the object at a fixed elevation, with
    azimuths evenly spaced over a full turn.

    Args:
        mesh: Mesh to render; cameras are created on ``mesh.device``.
        renderer: Callable invoked as ``renderer(mesh, cameras=...)`` that
            returns a batch of images whose last axis holds at least RGB.
        num_views: Number of viewpoints (>= 1).
        dist: Camera distance passed to ``look_at_view_transform``.
        elev: Camera elevation in degrees, shared by all views.

    Raises:
        ValueError: If ``num_views`` is smaller than 1.
    """
    if num_views < 1:
        raise ValueError("num_views must be at least 1")

    # azim=-180 and azim=180 are the same pose, so sample one extra point and
    # drop the duplicated endpoint. The elevation is held fixed: sweeping it
    # through a full turn together with the azimuth lands every camera on the
    # same position.
    azims = torch.linspace(-180, 180, num_views + 1)[:-1].tolist()

    images = []
    # Rendering for display only; no autograd graph is needed
    with torch.no_grad():
        for azim in azims:
            R, T = look_at_view_transform(dist=dist, elev=elev, azim=azim)
            cameras = FoVPerspectiveCameras(device=mesh.device, R=R, T=T)

            image = renderer(mesh, cameras=cameras)
            # Keep RGB only and clamp to the range imshow expects for floats
            images.append(image[0, ..., :3].clamp(0.0, 1.0).cpu())

    # squeeze=False always yields a 2-D axes array, even for a single view
    fig, axes = plt.subplots(1, num_views, figsize=(15, 5), squeeze=False)
    for ax, img in zip(axes[0], images):
        ax.imshow(img)
        ax.axis('off')
    plt.show()

# Visualize our starting point
visualize_mesh(base_mesh, renderer)

In [None]:
# Cell 5 (Updated): Define the mesh deformation model and losses
class MeshDeformationModel(torch.nn.Module):
    """
    Learnable mesh: a fixed template plus per-vertex offsets and colors.

    Learnable parameters:
        deform_verts: Offset added to every template vertex, initialised to 0.
        vertex_colors: One RGB triple per vertex, initialised to 0.5.

    Buffers:
        initial_verts: Template vertex positions (not saved in ``state_dict``).
        faces: Template faces; the topology is never changed.
    """

    def __init__(self, base_mesh, device=device):
        """
        Args:
            base_mesh: Template mesh providing ``verts_packed()`` and
                ``faces_packed()``.
            device: Torch device on which buffers and parameters are created.
        """
        super().__init__()

        # Template vertices, detached from any graph and moved to ``device``
        # so they always live next to the learnable offsets.
        initial_verts = base_mesh.verts_packed().detach().clone().to(device)

        # Registered as a buffer so ``model.to(...)`` moves it together with
        # the parameters. persistent=False keeps it out of state_dict, so the
        # checkpoint keys stay the same as before.
        self.register_buffer('initial_verts', initial_verts, persistent=False)

        # Learnable per-vertex offsets, starting at the undeformed template
        self.deform_verts = torch.nn.Parameter(torch.zeros_like(initial_verts))

        # Keep faces fixed - we only deform vertices
        self.register_buffer('faces', base_mesh.faces_packed().to(device))

        # Initialize vertex colors as learnable parameters (mid-grey)
        N = len(initial_verts)
        self.vertex_colors = torch.nn.Parameter(
            torch.ones((N, 3), device=device) * 0.5
        )

    def forward(self):
        """Return a single-mesh ``Meshes`` built from the current parameters."""
        # Apply deformation to vertices
        deformed_verts = self.initial_verts + self.deform_verts

        # Create texture from vertex colors; [None] adds the batch axis
        textures = TexturesVertex(
            verts_features=self.vertex_colors[None]
        )

        # Create and return mesh
        return Meshes(
            verts=[deformed_verts],
            faces=[self.faces],
            textures=textures
        )

class ReconstructionLoss:
    """
    Image-space and mesh-regularisation losses for fitting a mesh to a photo.

    Args:
        renderer: Callable invoked as ``renderer(mesh, cameras=camera)`` that
            returns a tensor indexed as (batch, H, W, 4), with RGB in
            channels 0-2 and the silhouette/alpha in channel 3.
    """

    def __init__(self, renderer):
        self.renderer = renderer

    @staticmethod
    def _match_resolution(target_image, height, width):
        """Return the (H, W, C) target bilinearly resized to (height, width, C)."""
        if tuple(target_image.shape[:2]) == (height, width):
            return target_image
        # interpolate() expects (N, C, H, W); convert, resize, convert back
        batched = target_image.permute(2, 0, 1).unsqueeze(0)
        resized = torch.nn.functional.interpolate(
            batched, size=(height, width), mode='bilinear', align_corners=False
        )
        return resized.squeeze(0).permute(1, 2, 0)

    def rgb_loss(self, pred_rgb, target_rgb):
        """Simple L1 loss between RGB images"""
        # Ensure both tensors are on the same device
        target_rgb = target_rgb.to(pred_rgb.device)
        return torch.abs(pred_rgb - target_rgb).mean()

    def silhouette_loss(self, pred_silhouette, target_silhouette):
        """Binary cross entropy loss for silhouettes"""
        # Ensure both tensors are on the same device
        target_silhouette = target_silhouette.to(pred_silhouette.device)
        # binary_cross_entropy requires probabilities inside [0, 1]
        return torch.nn.functional.binary_cross_entropy(
            pred_silhouette.clamp(min=0.0, max=1.0),
            target_silhouette
        )

    def mesh_regularization(self, mesh):
        """Regularization to encourage smooth mesh"""
        from pytorch3d.loss import mesh_laplacian_smoothing
        return mesh_laplacian_smoothing(mesh)

    def edge_regularization(self, mesh):
        """Regularization to prevent long edges"""
        from pytorch3d.loss import mesh_edge_loss
        return mesh_edge_loss(mesh)

    def compute_loss(self, mesh, target_image, camera):
        """Compute the full loss for a single view.

        Args:
            mesh: Mesh handed to the renderer and to both regularisers.
            target_image: (H, W, 3) float image. It is resized to the render
                resolution when the two differ.
            camera: Camera passed to the renderer as ``cameras=``.

        Returns:
            Tuple ``(total_loss, parts)``: the weighted sum as a tensor, and a
            dict of Python floats with one entry per loss term plus
            ``'total_loss'``.
        """
        # Render predicted image
        pred_image = self.renderer(mesh, cameras=camera)

        # Extract RGB and silhouette
        pred_rgb = pred_image[..., :3]
        pred_silhouette = pred_image[..., 3]

        # Bring the target onto the render's device and pixel grid. Without
        # this, a target whose size differs from the renderer's image_size
        # either fails to broadcast or is compared against the wrong pixels.
        target_image = target_image.to(pred_image.device)
        target_image = self._match_resolution(
            target_image, pred_rgb.shape[1], pred_rgb.shape[2]
        )

        # Convert target image to silhouette (simple threshold): pixels whose
        # mean intensity exceeds 0.1 count as foreground.
        target_silhouette = (target_image.mean(dim=-1) > 0.1).float()

        # Compute losses
        rgb_loss = self.rgb_loss(pred_rgb[0], target_image)
        sil_loss = self.silhouette_loss(pred_silhouette[0], target_silhouette)
        lap_loss = self.mesh_regularization(mesh)
        edge_loss = self.edge_regularization(mesh)

        # Combine losses with weights
        total_loss = (
            1.0 * rgb_loss +
            1.0 * sil_loss +
            0.1 * lap_loss +
            0.1 * edge_loss
        )

        return total_loss, {
            'rgb_loss': rgb_loss.item(),
            'silhouette_loss': sil_loss.item(),
            'laplacian_loss': lap_loss.item(),
            'edge_loss': edge_loss.item(),
            'total_loss': total_loss.item()
        }

# Build the deformable mesh model and the loss that scores its renders
model = MeshDeformationModel(base_mesh=base_mesh, device=device)
loss_fn = ReconstructionLoss(renderer=renderer)

# Smoke test: one forward pass, scored against the first input image from a
# camera at elev=0, azim=0
test_mesh = model()
R, T = look_at_view_transform(dist=2.7, elev=0, azim=0)
test_camera = FoVPerspectiveCameras(device=device, R=R, T=T)
test_loss, loss_dict = loss_fn.compute_loss(
    mesh=test_mesh,
    target_image=input_images[0],
    camera=test_camera,
)

print("Initial test loss values:")
for k, v in loss_dict.items():
    print(f"{k}: {v:.4f}")

# Show the untrained mesh
visualize_mesh(mesh=test_mesh, renderer=renderer)

In [None]:
# Cell 7: Training loop with CPU optimizations
from collections import defaultdict
import time
import tqdm.notebook

class TrainingManager:
    """
    Drives the optimisation of a mesh model against a set of target images.

    Args:
        model: Module whose call returns the current mesh.
        loss_fn: Object exposing ``compute_loss(mesh, target, camera)`` and a
            ``renderer`` attribute (used for the periodic visualisation).
        optimizer: Optimizer over the model parameters.
        input_images: Sequence of target image tensors.
        save_dir: Directory for checkpoints; created if missing.
        device: Torch device on which the cameras are created.
    """

    def __init__(self, model, loss_fn, optimizer, input_images,
                 save_dir='./checkpoints', device=device):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.input_images = input_images
        self.save_dir = Path(save_dir)
        # parents=True so a nested save_dir does not fail on a fresh machine
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.device = device

        # Training history: loss name -> list with one value per iteration
        self.history = defaultdict(list)
        self.best_loss = float('inf')

    def train_step(self, target_images, n_views=2):
        """Run one optimisation step over ``n_views`` camera views.

        Cameras are placed at elevation 0 with azimuths evenly spaced over a
        full turn. View ``i`` is scored against
        ``target_images[i % len(target_images)]``.

        NOTE(review): targets are paired with cameras purely by list position;
        nothing here ties an image to the pose it was taken from -- confirm
        this is intended.

        Args:
            target_images: Non-empty sequence of target image tensors.
            n_views: Number of camera views to render (>= 1).

        Returns:
            Dict of loss name -> value averaged over the views.

        Raises:
            ValueError: If ``n_views < 1`` or ``target_images`` is empty.
        """
        if n_views < 1:
            raise ValueError("n_views must be at least 1")
        if len(target_images) == 0:
            raise ValueError("train_step() requires at least one target image")

        self.model.train()

        # azim=-180 and azim=180 are the same pose, so sample one extra point
        # and drop the duplicated endpoint. The elevation is held fixed:
        # sweeping it through a full turn together with the azimuth lands
        # every camera on the same position.
        azims = torch.linspace(-180, 180, n_views + 1)[:-1].tolist()

        # Get current mesh (built once, rendered from every view)
        mesh = self.model()

        # Initialize total loss
        total_loss = 0
        all_losses = defaultdict(float)

        # Render the views one after another rather than as a batch
        for view_idx, azim in enumerate(azims):
            # Create camera for this view
            R, T = look_at_view_transform(dist=2.7, elev=0.0, azim=azim)
            camera = FoVPerspectiveCameras(device=self.device, R=R, T=T)

            # Get target image
            target = target_images[view_idx % len(target_images)]

            # Compute loss for this view
            loss, losses = self.loss_fn.compute_loss(mesh, target, camera)
            total_loss += loss / n_views

            # Accumulate the per-term averages for reporting
            for k, v in losses.items():
                all_losses[k] += v / n_views

        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return dict(all_losses)

    def train(self, n_iterations=1000, batch_size=4,
              log_every=10, save_every=100):
        """Main training loop with progress tracking.

        Args:
            n_iterations: Number of optimisation steps.
            batch_size: Number of images drawn at random (without replacement)
                per step; every drawn image is rendered from its own view.
            log_every: Print losses and visualise the mesh every this many
                iterations.
            save_every: Write ``checkpoint_<iteration>.pt`` every this many
                iterations. The best step so far is additionally kept in
                ``checkpoint_best.pt``.

        Returns:
            The accumulated history (loss name -> list of per-iteration values).

        Raises:
            ValueError: If there are no input images.
        """
        if len(self.input_images) == 0:
            raise ValueError("train() requires at least one input image")

        # Create progress bar
        pbar = tqdm.notebook.trange(n_iterations)
        start_time = time.time()

        try:
            for iteration in pbar:
                # Select random subset of images for this iteration
                batch_indices = torch.randperm(len(self.input_images))[:batch_size].tolist()
                batch_images = [self.input_images[i] for i in batch_indices]

                # One view per drawn image, so the whole batch is used
                losses = self.train_step(batch_images, n_views=len(batch_images))

                # Update history
                for k, v in losses.items():
                    self.history[k].append(v)

                # Update progress bar
                pbar.set_postfix({
                    'loss': f"{losses['total_loss']:.4f}",
                    'rgb_loss': f"{losses['rgb_loss']:.4f}"
                })

                # Periodic logging
                if iteration % log_every == 0:
                    elapsed = time.time() - start_time
                    print(f"\nIteration {iteration}")
                    print(f"Time elapsed: {elapsed:.1f}s")
                    for k, v in losses.items():
                        print(f"{k}: {v:.4f}")

                    # Visualize current state with the renderer the loss uses
                    with torch.no_grad():
                        current_mesh = self.model()
                        visualize_mesh(current_mesh, self.loss_fn.renderer)

                # Record the new best *before* saving so the checkpoint stores
                # the right value, and overwrite one file instead of writing a
                # new one on every improving iteration.
                if losses['total_loss'] < self.best_loss:
                    self.best_loss = losses['total_loss']
                    self.save_checkpoint(iteration, losses,
                                         filename='checkpoint_best.pt')

                # Periodic snapshot
                if iteration % save_every == 0:
                    self.save_checkpoint(iteration, losses)

        except KeyboardInterrupt:
            # Allow stopping from the notebook while keeping the history
            print("\nTraining interrupted by user")

        return self.history

    def save_checkpoint(self, iteration, losses, filename=None):
        """Save a checkpoint with model/optimizer state and loss history.

        Args:
            iteration: Iteration number stored in the checkpoint and used for
                the default file name.
            losses: Losses of the current step. Accepted for interface
                compatibility; the checkpoint stores the full history instead.
            filename: Optional file name inside ``save_dir``. Defaults to
                ``checkpoint_<iteration:04d>.pt``.
        """
        checkpoint = {
            'iteration': iteration,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'losses': dict(self.history),
            'best_loss': self.best_loss
        }
        if filename is None:
            filename = f'checkpoint_{iteration:04d}.pt'
        path = self.save_dir / filename
        torch.save(checkpoint, path)
        print(f"\nSaved checkpoint to {path}")

# Optimizer: one Adam parameter group per learnable tensor, both at lr=1e-3
optimizer = torch.optim.Adam([
    {'params': learnable, 'lr': 1e-3}
    for learnable in (model.deform_verts, model.vertex_colors)
])

# Training driver; checkpoints go to ./reconstruction_checkpoints
trainer = TrainingManager(
    model,
    loss_fn,
    optimizer,
    input_images,
    save_dir='./reconstruction_checkpoints',
)

# Run the optimisation and keep the per-iteration loss history
history = trainer.train(
    save_every=100,
    log_every=20,
    batch_size=4,       # Small batch size for CPU
    n_iterations=1000,  # Reduced for CPU
)