In [1]:
from posenet.model import PoseNetDinoImproved
from nflownet.model import NFlowNet
import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from dataset.tartanair import TartanAirDataset
import random
import numpy as np
from cheirality.cheiralityLayer import CheiralityLayer

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed = 42
set_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
def load_posenet(posenet_path):
    # Initialize FlowNet model
    posenet = PoseNetDinoImproved().to(device)
    
    # Load weights from .pth file
    checkpoint = torch.load(posenet_path, map_location=device)
    
    # Handle different save formats
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    
    # Load state dict into model
    posenet.load_state_dict(state_dict)
    posenet.train()  
    
    return posenet

In [3]:
def load_nflownet(nflownet_path):
    # Initialize FlowNet model
    nflownet = NFlowNet().to(device)
    
    # Load weights from .pth file
    checkpoint = torch.load(nflownet_path, map_location=device)
    
    # Handle different save formats
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    
    # Load state dict into model
    nflownet.load_state_dict(state_dict)
    nflownet.eval()  # Set to evaluation mode
    
    return nflownet

In [4]:
posenet_path = r"D:\KOC UNIVERSITY\COMP447\trainedmodels\posenet\posenet2.pth"
nflownet_path = r"D:\KOC UNIVERSITY\COMP447\trainedmodels\nflownet\nflownet_final.pth"

posenet = load_posenet(posenet_path)
nflownet = load_nflownet(nflownet_path)

Using cache found in C:\Users\emircan/.cache\torch\hub\facebookresearch_dinov2_main
  checkpoint = torch.load(posenet_path, map_location=device)


In [5]:
cheirality_layer = CheiralityLayer(nflownet=nflownet,posenet=posenet).to(device)

In [6]:
# Dataset and DataLoader
dataset = TartanAirDataset(root_dir="D:/KOC UNIVERSITY/COMP447/tartanair_dataset",seq_len=6)
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [7]:
optimizer = optim.Adam(posenet.parameters(), lr=1e-4)

num_epochs = 1
for epoch in range(num_epochs):
    total_loss = 0.0
    
    for images, translation, rotation in train_loader:
        images = images.to(device)
        loss = cheirality_layer(images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        print(loss.item())
    print(f"Epoch {epoch}, Loss: {total_loss / len(train_loader)}")
        

Upper level loss:
tensor(2.5276e+17, device='cuda:0', grad_fn=<MseLossBackward0>)
-1.6828477860520002e+29
Upper level loss:
tensor(7.0259e+23, device='cuda:0', grad_fn=<MseLossBackward0>)
nan


KeyboardInterrupt: 