# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import cv2
from PIL import Image
import sys
from argparse import Namespace
from collections import OrderedDict
import glob
import random

sys.path.append('core')
from raft import RAFT
from utils.utils import InputPadder

# Compute device for the model and batches: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# .flo reader and dataset class for the DNS and Uniform flow data

def read_flo_file(path):
    """Read a Middlebury-style ``.flo`` optical-flow file.

    Layout: a float32 magic number (202021.25), int32 width, int32 height,
    then ``2 * width * height`` float32 values, with (u, v) interleaved
    per pixel in row-major order. All fields are little-endian.

    Args:
        path: path to the ``.flo`` file.

    Returns:
        A native-endian ``float32`` array of shape ``(height, width, 2)``.

    Raises:
        ValueError: if the magic number is wrong or missing, the header is
            truncated, the dimensions are non-positive, or the file holds
            fewer flow values than the header promises.
    """
    with open(path, 'rb') as f:
        magic = np.fromfile(f, dtype='<f4', count=1)
        if magic.size != 1 or magic[0] != 202021.25:
            raise ValueError('Magic number incorrect. Invalid .flo file')
        dims = np.fromfile(f, dtype='<i4', count=2)
        if dims.size != 2:
            raise ValueError(f'Truncated .flo header in {path!r}')
        width, height = int(dims[0]), int(dims[1])
        if width <= 0 or height <= 0:
            raise ValueError(f'Invalid .flo dimensions {width}x{height} in {path!r}')
        expected = 2 * width * height
        values = np.fromfile(f, dtype='<f4', count=expected)

    # Fail loudly on a short file. np.resize would silently recycle data
    # into a plausible-looking but wrong flow field.
    if values.size != expected:
        raise ValueError(
            f'.flo file {path!r} holds {values.size} values, expected {expected}'
        )
    # astype is a no-op on little-endian hosts. On big-endian hosts it yields a
    # native-order array, which torch.from_numpy requires.
    return values.reshape(height, width, 2).astype(np.float32, copy=False)

class FlowDataset(data.Dataset):
    """Image-pair / ground-truth-flow samples from the DNS and Uniform directories.

    Each directory is expected to hold triplets named ``<base>_img1.tif``,
    ``<base>_img2.tif`` and ``<base>_flow.flo``. Samples are discovered from the
    ``*_flow.flo`` files; the image paths are derived from the same ``<base>``
    and are not checked for existence here. Each item is ``(img1, img2, flow)``
    as float tensors of shape ``(3, H, W)``, ``(3, H, W)`` and ``(2, H, W)``.

    Args:
        dns_dir: directory with the DNS samples.
        uniform_dir: directory with the Uniform samples.
        max_samples: take at most this many samples (first in sorted filename
            order) from EACH directory; ``None`` means no limit.
    """

    _FLOW_SUFFIX = '_flow.flo'

    def __init__(self, dns_dir, uniform_dir, max_samples=1000):
        self.samples = []
        for root in (dns_dir, uniform_dir):
            self.samples.extend(self._collect_samples(root, max_samples))
        # DataLoader(shuffle=True) already reshuffles every epoch. This only
        # randomizes the dataset's own index -> sample mapping.
        random.shuffle(self.samples)

    @classmethod
    def _collect_samples(cls, root, max_samples):
        """Return (img1, img2, flow) path triplets for the sorted ``*_flow.flo`` files in ``root``."""
        # Escape the directory so characters like '[' in the path are not glob syntax.
        pattern = os.path.join(glob.escape(root), '*' + cls._FLOW_SUFFIX)
        flow_paths = sorted(glob.glob(pattern))
        if max_samples is not None:
            flow_paths = flow_paths[:max_samples]

        triplets = []
        for flow_path in flow_paths:
            # Strip only the trailing suffix that the glob guarantees is present.
            base = os.path.basename(flow_path)[:-len(cls._FLOW_SUFFIX)]
            img1_path = os.path.join(root, f'{base}_img1.tif')
            img2_path = os.path.join(root, f'{base}_img2.tif')
            triplets.append((img1_path, img2_path, flow_path))
        return triplets

    def __len__(self):
        return len(self.samples)

    def preprocess(self, path):
        """Load an image as a float32 ``(3, H, W)`` tensor of raw pixel values.

        No rescaling is done here: RAFT's ``forward`` normalizes its inputs
        itself with ``2 * (x / 255.0) - 1``. Dividing by 255 beforehand would
        feed the pretrained encoders near-constant images.
        NOTE(review): assumes 8-bit images (values in [0, 255]); confirm for
        the DNS/Uniform TIFFs.
        """
        img = np.array(Image.open(path))
        if img.ndim == 2:  # grayscale -> replicate to 3 channels
            img = np.stack([img] * 3, axis=-1)
        elif img.shape[-1] == 1:  # single-channel (H, W, 1) -> 3 channels
            img = np.repeat(img, 3, axis=-1)
        elif img.shape[-1] == 4:  # RGBA -> drop alpha
            img = img[..., :3]
        # Convert in numpy first: this also covers dtypes such as uint16 that
        # torch.from_numpy may not accept.
        return torch.from_numpy(img.astype(np.float32)).permute(2, 0, 1)

    def __getitem__(self, idx):
        """Return ``(img1, img2, flow)`` for sample ``idx``; ``flow`` is ``(2, H, W)``."""
        img1_path, img2_path, flow_path = self.samples[idx]
        img1 = self.preprocess(img1_path)
        img2 = self.preprocess(img2_path)
        flow = read_flo_file(flow_path)
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()
        return img1, img2, flow

# RAFT model wrapper
class RAFTWrapper(nn.Module):
    """nn.Module around RAFT that pads inputs and returns the final upsampled flow.

    Args:
        args: namespace forwarded unchanged to ``RAFT``.
    """

    def __init__(self, args):
        super(RAFTWrapper, self).__init__()
        self.raft = RAFT(args)

    def forward(self, image1, image2, iters=12):
        """Predict flow from ``image1`` to ``image2``.

        Args:
            image1, image2: image batches of identical shape; only the last
                two dimensions (H, W) are used for padding.
            iters: number of RAFT refinement iterations.

        Returns:
            The upsampled flow from the final iteration, cropped back to the
            original (unpadded) H x W, so it lines up with ground-truth flow.
        """
        padder = InputPadder(image1.shape)
        image1, image2 = padder.pad(image1, image2)
        # test_mode=True makes RAFT return a 2-tuple whose second element is
        # the upsampled flow. Gradients still flow through it.
        # NOTE(review): this trains on the last iterate only, not RAFT's
        # sequence loss — confirm this is intended.
        _, flow_up = self.raft(image1, image2, iters=iters, test_mode=True)
        # Without unpadding, the prediction is larger than flow_gt whenever
        # H or W is not a multiple of the padding stride.
        return padder.unpad(flow_up)

# Training loop with gradual unfreezing
if __name__ == '__main__':
    args = Namespace(
        small=False,
        mixed_precision=False,
        alternate_corr=False,
        model='/Users/edasaruhan21/RAFT/models/raft-sintel.pth'
    )

    dns_dir = '/Users/edasaruhan21/DNS'
    uniform_dir = '/Users/edasaruhan21/uniform'
    save_path = '/Users/edasaruhan21/RAFT/raft_flow_only_finetuned_not_strach.pth'

    num_epochs = 10
    unfreeze_epoch = 3  # 0-based epoch at which the whole network becomes trainable
    learning_rate = 1e-5

    model = RAFTWrapper(args).to(DEVICE)

    # Load pretrained weights. Checkpoints saved from nn.DataParallel prefix
    # every key with 'module.'. Strip only that leading prefix, not the
    # substring anywhere in the key.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints (recent PyTorch versions accept weights_only=True).
    state_dict = torch.load(args.model, map_location=DEVICE)
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model.raft.load_state_dict(new_state_dict)

    # Phase 1: train only the update block; everything else stays frozen.
    for name, param in model.raft.named_parameters():
        param.requires_grad = 'update_block' in name

    dataset = FlowDataset(dns_dir, uniform_dir)
    if len(dataset) == 0:
        raise RuntimeError(
            f'No *_flow.flo samples found in {dns_dir!r} or {uniform_dir!r}'
        )
    dataloader = data.DataLoader(dataset, batch_size=16, shuffle=True)

    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=learning_rate
    )
    criterion_flow = nn.MSELoss()

    model.train()
    for epoch in range(num_epochs):
        if epoch == unfreeze_epoch:
            # Phase 2: unfreeze the rest. Register the newly trainable tensors
            # as an extra param group instead of rebuilding the optimizer, so
            # the Adam state already accumulated for the update block survives.
            newly_unfrozen = [p for p in model.raft.parameters() if not p.requires_grad]
            for param in newly_unfrozen:
                param.requires_grad = True
            if newly_unfrozen:
                optimizer.add_param_group({'params': newly_unfrozen})
            print("Unfroze entire model for full fine-tuning.")

        total_loss = 0.0
        for i, (img1, img2, flow_gt) in enumerate(dataloader):
            img1, img2, flow_gt = img1.to(DEVICE), img2.to(DEVICE), flow_gt.to(DEVICE)
            optimizer.zero_grad()
            flow_pred = model(img1, img2)

            loss = criterion_flow(flow_pred, flow_gt)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            print(f"Epoch {epoch+1}, Iteration {i+1}: Loss = {loss.item():.4f}")

        print(f"Epoch {epoch+1} Completed: Avg Loss = {total_loss/len(dataloader):.4f}")

    # Saved keys carry the wrapper's 'raft.' prefix (the submodule attribute
    # name). Load them back into a RAFTWrapper, or strip the prefix for bare RAFT.
    torch.save(model.state_dict(), save_path)
