In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import cv2
from PIL import Image
import sys
from argparse import Namespace
from collections import OrderedDict
import glob
import random

import matplotlib.pyplot as plt
from argparse import Namespace
from collections import OrderedDict

sys.path.append('core')
from raft import RAFT
from utils.utils import InputPadder

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

class RAFTWrapper(torch.nn.Module):
    """Thin ``nn.Module`` around RAFT that handles input padding for inference.

    The wrapped network is stored as ``self.raft``, so this module's state-dict
    keys carry a ``raft.`` prefix.
    """

    def __init__(self, args):
        """Build the underlying RAFT network.

        Args:
            args: namespace forwarded unchanged to ``RAFT(args)`` (this file
                passes ``small``, ``mixed_precision``, ``alternate_corr`` and
                ``model`` fields).
        """
        super().__init__()
        self.raft = RAFT(args)

    def forward(self, image1, image2, iters=12):
        """Estimate the upsampled optical flow from ``image1`` to ``image2``.

        Args:
            image1, image2: image tensors of identical shape; presumably
                (N, 3, H, W) — this file feeds ``preprocess`` output,
                which is (1, 3, H, W). TODO confirm other callers.
            iters: number of RAFT refinement iterations passed to the network
                (default 12, the value previously hard-coded).

        Returns:
            The final (``flow_up``) flow prediction, cropped back to the
            spatial size of the unpadded inputs so it aligns pixel-for-pixel
            with ``image1``.
        """
        padder = InputPadder(image1.shape)
        image1, image2 = padder.pad(image1, image2)
        # test_mode=True makes RAFT return (low-res flow, upsampled flow).
        _, flow_up = self.raft(image1, image2, iters=iters, test_mode=True)
        # InputPadder enlarged H/W (presumably to a multiple of 8 as in
        # upstream RAFT); without cropping, the flow would be larger than and
        # offset from the input image.
        return padder.unpad(flow_up)

def preprocess(path):
    """Load an image file as a single-image float tensor batch.

    The file is decoded with PIL and converted to 3-channel RGB, so grayscale
    or RGBA files are accepted as well.

    Args:
        path: filesystem path of the image.

    Returns:
        ``float32`` tensor of shape (1, 3, H, W) with values in [0, 1].

    NOTE(review): upstream RAFT's ``forward`` rescales its inputs itself
    (``image / 255.0``) and its demo loader passes raw 0-255 values. Confirm
    that dividing by 255 here matches how this checkpoint was trained.
    """
    hwc = np.array(Image.open(path).convert('RGB'))
    chw = torch.from_numpy(hwc).permute(2, 0, 1)
    return (chw.float() / 255.0)[None]

def flow_to_image(flow):
    """Render the first flow field of a batch as an RGB visualisation.

    Hue encodes direction: ``arctan2(v, u)`` is mapped to [0, 1] and looked up
    in matplotlib's ``hsv`` colormap. Brightness encodes magnitude, normalised
    by the field's maximum.

    Args:
        flow: tensor where ``flow[0][0]`` is u and ``flow[0][1]`` is v;
            presumably shaped (N, 2, H, W) — TODO confirm.

    Returns:
        ``uint8`` array of shape (H, W, 3).
    """
    field = flow[0].cpu().numpy()
    u = field[0]
    v = field[1]

    magnitude = np.sqrt(u ** 2 + v ** 2)
    # Cast to float32 to match the precision the values were previously
    # stored at before colour mapping.
    hue = ((np.arctan2(v, u) + np.pi) / (2 * np.pi)).astype(np.float32)
    brightness = np.clip(magnitude / np.max(magnitude + 1e-6), 0, 1).astype(np.float32)

    # Drop the alpha channel returned by the colormap.
    colours = plt.cm.hsv(hue)[:, :, :3]
    shaded = colours * brightness[..., None]
    return (shaded * 255).astype(np.uint8)

if __name__ == '__main__':
    # Script entry:
    #   1. estimate optical flow between the first two PNG frames of test_folder;
    #   2. save a jet-colour-mapped magnitude plot;
    #   3. write the magnitudes along one image column to a text file.
    args = Namespace(
        small=False,
        mixed_precision=False,
        alternate_corr=False,
        model='/Users/edasaruhan21/RAFT/raft_flow_only_finetuned_not_strach.pth'
    )

    test_folder = '/Users/edasaruhan21/22s/sample16'
    # Lexicographic filename order decides which two frames form the pair.
    test_imgs = sorted(glob.glob(os.path.join(test_folder, '*.png')))[:2]
    if len(test_imgs) < 2:
        # Fail with a clear message instead of an IndexError on test_imgs[1].
        raise FileNotFoundError(
            f"Expected at least 2 PNG frames in {test_folder}, found {len(test_imgs)}"
        )

    img1 = preprocess(test_imgs[0]).to(DEVICE)
    img2 = preprocess(test_imgs[1]).to(DEVICE)

    model = RAFTWrapper(args).to(DEVICE)
    # Load the checkpoint once, from args.model rather than a duplicated path.
    # Strip a leading 'module.' prefix (as written by nn.DataParallel-wrapped
    # models) so both plain and DataParallel checkpoints load.
    state_dict = torch.load(args.model, map_location=DEVICE)
    new_state_dict = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in state_dict.items()
    }
    if any(k.startswith('raft.') for k in new_state_dict):
        # Keys follow RAFTWrapper's own layout (inner network under 'raft.').
        model.load_state_dict(new_state_dict)
    else:
        # Checkpoint of a bare RAFT network: load it into the wrapped module.
        model.raft.load_state_dict(new_state_dict)
    model.eval()

    with torch.no_grad():
        flow = model(img1, img2)
        # RGB direction/magnitude visualisation. It is not used below but is
        # kept as a global, since later notebook cells may reference it
        # (TODO confirm).
        flow_img = flow_to_image(flow)

    # Per-pixel magnitude sqrt(u^2 + v^2) over the 2-channel flow (dim 1).
    # Squeezing drops the batch dimension of 1, leaving an (H, W) array.
    flow_mag = torch.norm(flow, dim=1).squeeze().cpu().numpy()

    plt.figure(figsize=(10, 5))
    # Fixed colour scale [0, 14]; larger magnitudes saturate at the top colour.
    im = plt.imshow(flow_mag, cmap='jet', vmin=0, vmax=14)
    plt.axis('off')
    plt.title("Optical Flow Magnitude (Jet)")

    cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
    cbar.set_label('Flow Magnitude', rotation=270, labelpad=15)

    plt.tight_layout()
    flow3 = flow_mag  # alias kept from the original; unused in this script

    plt.savefig('/Users/edasaruhan21/RAFT/124.png', bbox_inches='tight', pad_inches=0)

    # Dump the magnitude profile along image column x_index (one value per row).
    x_index = 100
    if 0 <= x_index < flow_mag.shape[1]:
        flow_column = flow_mag[:, x_index]
        np.savetxt('/Users/edasaruhan21/RAFT/neww_melis_stage22_sample16_finetuned_dns_100.txt', flow_column, fmt='%.6f')
    else:
        # Report the index actually used (the old message hard-coded "x=400").
        print(f"x={x_index} is out of bounds for image width {flow_mag.shape[1]}")
