In [1]:
from dataset import DictDataset, RepeatedDictDataset
from model import *
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from importlib import reload
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb
import torch
import os
from loss_functions import *
from inverse_warp import inverse_warp


In [2]:
from torch.utils.data import DataLoader

# dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# sample = dataloader.__iter__().__next__()
# bigmodel = BigModel()
# pose_final, depth_map = bigmodel(sample)

In [3]:
repeatdataset = RepeatedDictDataset('./data/folder_0_pair_0.pt', 100)

In [4]:
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F

def train_model(bigmodel, 
                train_dataset, 
                val_dataset, 
                num_epochs=10, 
                batch_size=2,
                lr=1e-3,
                device='cpu',
                optimizer_cls=optim.Adam,
                patience=3,
                log_interval=10,
                save_dir="models",  # Directory to save the model
                save_name="best_model.pth"  # Model name to save
                ):
    """
    Train a model with photometric and smooth loss for depth estimation.
    
    Args:
        bigmodel: The PyTorch model to be trained.
        train_dataset: Training dataset.
        val_dataset: Validation dataset.
        num_epochs: Number of epochs to train.
        batch_size: Batch size for training.
        lr: Learning rate.
        device: Device to run the model ('cpu' or 'cuda').
        optimizer_cls: Optimizer class (e.g., torch.optim.Adam).
        patience: Early stopping patience.
        log_interval: Logging interval for progress display.
        save_dir: Directory to save the best model.
        save_name: Name for the saved model.
    """
    # Camera intrinsics
    intrinsics_flat = [9.569475e+02, 0.000000e+00, 6.939767e+02,
                       0.000000e+00, 9.522352e+02, 2.386081e+02,
                       0.000000e+00, 0.000000e+00, 1.000000e+00]

    # Move the model to the device
    bigmodel.to(device)

    # Set up data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    optimizer = optimizer_cls(bigmodel.parameters(), lr=lr)

    best_val_loss = float('inf')
    patience_counter = 0

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, save_name)

    # Training loop
    for epoch in range(num_epochs):
        bigmodel.train()
        train_loss = 0.0

        # Use tqdm for progress tracking
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for i, sample in enumerate(progress_bar):
            # Zero the gradients
            optimizer.zero_grad()

            # Prepare inputs
            B = sample['image_t1']['processed_image'].shape[0]  # Batch size
            intrinsics_matrix = torch.tensor(intrinsics_flat).view(1, 3, 3).repeat(B, 1, 1).to(device)

            tgt_image = sample['image_t1']['processed_image'].to(device)  # Target image [B, 3, H, W]
            ref_image = sample['image_t']['processed_image'].to(device)  # Reference image [B, 3, H, W]

            # Forward pass
            pose, depth_map = bigmodel(sample)

            # Photometric reconstruction loss
            photometric_loss = photometric_reconstruction_loss(
                tgt_img=tgt_image,
                ref_img=ref_image,
                intrinsics=intrinsics_matrix,
                depth=depth_map.squeeze(1),  # Remove channel dimension [B, H, W]
                pose=pose
            )

            # Smooth loss
            if depth_map.dim() == 3:  # [B, H, W]
                depth_map = depth_map.unsqueeze(1)  # [B, 1, H, W]

            smoothness_loss = smooth_loss(depth_map)

            # Total loss
            loss = photometric_loss + 0.01 * smoothness_loss

            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()

            # Accumulate train loss
            train_loss += loss.item()

            if i % log_interval == 0:
                avg_loss = train_loss / (i + 1)
                progress_bar.set_postfix(loss=avg_loss)

        # Validation loop (optional, uncomment if needed)
        # bigmodel.eval()
        # val_loss = 0.0
        # with torch.no_grad():
        #     for sample in val_loader:
        #         tgt_image = sample['image_t1']['processed_image'].to(device)
        #         ref_image = sample['image_t']['processed_image'].to(device)
        #         pose, depth_map = bigmodel(sample)
        #         photometric_loss = photometric_reconstruction_loss(
        #             tgt_img=tgt_image,
        #             ref_img=ref_image,
        #             intrinsics=intrinsics_matrix,
        #             depth=depth_map.squeeze(1),
        #             pose=pose
        #         )
        #         smoothness_loss = smooth_loss(depth_map)
        #         val_loss += (photometric_loss + 0.01 * smoothness_loss).item()
        #
        # val_loss /= len(val_loader)
        # if val_loss < best_val_loss:
        #     best_val_loss = val_loss
        #     patience_counter = 0
        #     torch.save(bigmodel.state_dict(), save_path)
        # else:
        #     patience_counter += 1

        # if patience_counter >= patience:
        #     print("Early stopping triggered!")
        #     break


In [5]:
big = BigModel()
train_model(bigmodel = big,
            train_dataset = repeatdataset,
            val_dataset = repeatdataset)

  data = torch.load(self.file_path)
Epoch 1/10: 100%|██████████| 50/50 [00:10<00:00,  4.94it/s, loss=0.0533]
Epoch 2/10:  88%|████████▊ | 44/50 [00:07<00:01,  5.98it/s, loss=0.0376]


KeyboardInterrupt: 