In [1]:
from posenet.model import PoseNetDino
from nflownet.model import NFlowNet
import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from dataset.tartanair import TartanAirDataset
import random
import numpy as np
from cheirality.cheiralityLayer import CheiralityLayer

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed = 42
set_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset and DataLoader
dataset = TartanAirDataset(root_dir="D:/KOC UNIVERSITY/COMP447/data/image_left")
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)


In [2]:
from safetensors import safe_open
def load_posenet(posenet_path):
    # Initialize PoseNet model
    posenet = PoseNetDino().to(device)
    
    # Load weights from .safetensors file
    with safe_open(posenet_path, framework="pt", device="cpu") as f:
        state_dict = {}
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    
    # Load state dict into model
    posenet.load_state_dict(state_dict)
    posenet.train()
    
    return posenet

In [3]:
def load_nflownet(nflownet_path):
    # Initialize FlowNet model
    nflownet = NFlowNet().to(device)
    
    # Load weights from .pth file
    checkpoint = torch.load(nflownet_path, map_location=device)
    
    # Handle different save formats
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    
    # Load state dict into model
    nflownet.load_state_dict(state_dict)
    nflownet.eval()  # Set to evaluation mode
    
    return nflownet

In [4]:
posenet_path = r"D:\KOC UNIVERSITY\COMP447\trainedmodels\posenet\model.safetensors"
nflownet_path = r"D:\KOC UNIVERSITY\COMP447\trainedmodels\nflownet\nflownet_final.pth"

#posenet = load_posenet(posenet_path
 
posenet = PoseNetDino().to(device)
nflownet = load_nflownet(nflownet_path)

Using cache found in C:\Users\emircan/.cache\torch\hub\facebookresearch_dinov2_main
  checkpoint = torch.load(nflownet_path, map_location=device)


In [5]:
cheirality_layer = CheiralityLayer(nflownet=nflownet,posenet=posenet).to(device)

In [6]:
optimizer = optim.Adam(posenet.parameters(), lr=1e-4)

num_epochs = 1
for epoch in range(num_epochs):
    total_loss = 0.0
    
    for images, _, _ in train_loader:
        images = images.to(device)
        loss = cheirality_layer(images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        print(loss.item())
    print(f"Epoch {epoch}, Loss: {total_loss / len(train_loader)}")
        

2.1275472956413394e+26
nan
nan
nan
nan


KeyboardInterrupt: 