In [1]:
from dataset import DictDataset, RepeatedDictDataset
from model import *
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from importlib import reload
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb
import torch
import os
from loss_functions import *
from inverse_warp import inverse_warp


In [2]:
from torch.utils.data import DataLoader

# dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# sample = dataloader.__iter__().__next__()
# bigmodel = BigModel()
# pose_final, depth_map = bigmodel(sample)

In [3]:
repeatdataset = RepeatedDictDataset('./data/folder_0_pair_0.pt', 100)

In [4]:
def train_model(bigmodel, 
                train_dataset, 
                val_dataset, 
                num_epochs=1, 
                batch_size=2,
                lr=1e-3,
                device='cpu',
                optimizer=optim.Adam,
                criterion=smooth_loss,
                patience=3,
                log_interval=10,
                wandb_project='depth-estimation',
                wandb_run_name='default',
                save_dir="models",  # Directory to save the model
                save_name="best_model.pth"  # Model name to save
                ):
    
    # Initialize wandb
    # wandb.init(project=wandb_project, name=wandb_run_name, config={
    #     "epochs": num_epochs,
    #     "batch_size": batch_size,
    #     "learning_rate": lr,
    #     "device": device,
    #     "optimizer": optimizer.__name__,
    #     "loss_function": criterion.__class__.__name__ if criterion else "None",
    #     "early_stopping_patience": patience,
    # })
    
    intrinsics_flat = [9.569475e+02, 0.000000e+00, 6.939767e+02,
                       0.000000e+00, 9.522352e+02, 2.386081e+02,
                       0.000000e+00, 0.000000e+00, 1.000000e+00]
    
    bigmodel.to(device)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    optimizer = optimizer(bigmodel.parameters(), lr=lr)
    
    best_val_loss = float('inf')
    patience_counter = 0
    
    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, save_name)
    
    # Training loop
    for epoch in range(num_epochs):
        bigmodel.train()
        train_loss = 0.0

        # Use tqdm for progress tracking
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for i, sample in enumerate(progress_bar):
            # Zero the gradients
            optimizer.zero_grad()
            B = sample['image_t1']['processed_image'].shape[0]  # Example to get batch size
            intrinsics_matrix = torch.tensor(intrinsics_flat).view(1, 3, 3).repeat(B, 1, 1).to(device)

            # Forward pass
            pose, depth_map = bigmodel(sample)
            # B, H, W = batch_size, 224, 224  # Batch size, Height, Width
            # num_scales = 1         # Number of scales

            # depth_map = [torch.randn(B, 1, H // (2 ** s), W // (2)) for s in range(num_scales)]
            
            # loss = criterion(pred_depth=depth_map, 
            #                  pred_poses=pose, 
            #                  tgt_image=sample['image_t1']['processed_image'], 
            #                  src_image_stack=sample['image_t']['processed_image'],
            #                  intrinsics=intrinsics_matrix)
            loss = smooth_loss(pred_map=depth_map)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
            if i % log_interval == 0:
                avg_loss = train_loss / (i + 1)
                progress_bar.set_postfix(loss=avg_loss)
                print(avg_loss)
                # wandb.log({"epoch": epoch + 1, "batch_loss": avg_loss})
                
        # Validation loop
        # bigmodel.eval()
        # val_loss = 0.0
        # with torch.no_grad():
        #     for sample in val_loader:
        #         pose, depth_map = bigmodel(sample)
        #         loss = criterion(pred_depth=depth_map, 
        #                          pred_poses=pose, 
        #                          tgt_image=sample['image_t1']['processed_image'], 
        #                          src_image_stack=sample['image_t']['processed_image'],
        #                          intrinsics=intrinsics_matrix)
        #         val_loss += loss.item()
        
        # val_loss /= len(val_loader)
        # wandb.log({"epoch": epoch + 1, "val_loss": val_loss})
        
        # # Early Stopping Check
        # if val_loss < best_val_loss:
        #     best_val_loss = val_loss
        #     patience_counter = 0  # Reset patience counter
            
        #     # Save the best model
        #     torch.save(bigmodel.state_dict(), save_path)
        #     print(f"Model saved at {save_path}")
        #     wandb.log({"best_val_loss": best_val_loss, "model_saved": True, "save_path": save_path})
        # else:
        #     patience_counter += 1
        #     print(f"EarlyStopping: No improvement for {patience_counter} epoch(s).")

        # if patience_counter >= patience:
        #     print("Early stopping triggered!")
        #     break


In [5]:
big = BigModel()
train_model(bigmodel = big,
            train_dataset = repeatdataset,
            val_dataset = repeatdataset)

Epoch 1/1:  10%|█         | 5/50 [00:00<00:01, 24.97it/s, loss=0.00165]

0.0016510332934558392


Epoch 1/1:  32%|███▏      | 16/50 [00:00<00:01, 29.89it/s, loss=0.00153]

0.001525500768118284


Epoch 1/1:  48%|████▊     | 24/50 [00:00<00:00, 30.68it/s, loss=0.00105]

0.0010483377887534776


Epoch 1/1:  68%|██████▊   | 34/50 [00:01<00:00, 25.65it/s, loss=0.000808]

0.0008081146562847519


Epoch 1/1:  94%|█████████▍| 47/50 [00:01<00:00, 27.87it/s, loss=0.000652]

0.0006515655582117644


Epoch 1/1: 100%|██████████| 50/50 [00:01<00:00, 27.77it/s, loss=0.000652]


In [6]:
intrinsics_flat = [9.569475e+02, 0.000000e+00, 6.939767e+02,
                   0.000000e+00, 9.522352e+02, 2.386081e+02,
                   0.000000e+00, 0.000000e+00, 1.000000e+00]
# pose_final = pose_final.unsqueeze(1)
# #print(f"pose_final shape: {pose_final.shape}")  # Should be [B, 1, 6]
# # Convert to a 3x3 matrix
B = image_t.shape[0]  # Example to get batch size
intrinsics_matrix = torch.tensor(intrinsics_flat).view(1, 3, 3).repeat(B, 1, 1)  # Shape: [B, 3, 3]
# compute_loss(pred_depth=depth_map, pred_poses=pose_final, tgt_image=image_t1, src_image_stack=image_t,intrinsics=intrinsics_matrix)

NameError: name 'image_t' is not defined