# In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from argparse import Namespace
from torchvision import transforms

# Make RAFT's `core/` directory importable. `__file__` is not defined when this
# code runs as a notebook cell, so fall back to the current working directory.
try:
    _RAFT_ROOT = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _RAFT_ROOT = os.getcwd()
sys.path.append(os.path.join(_RAFT_ROOT, 'core'))
from raft import RAFT
from utils.utils import InputPadder

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def preprocess_image(path):
    """Load the image at *path* as a batched RGB float tensor on ``DEVICE``.

    The image is forced to 3-channel RGB, converted with
    ``transforms.ToTensor`` (C x H x W, values scaled to [0, 1]) and given a
    leading batch dimension, yielding shape [1, 3, H, W].
    """
    rgb = Image.open(path).convert('RGB')
    chw = transforms.ToTensor()(rgb)
    batched = chw.unsqueeze(0)
    return batched.to(DEVICE)

def create_fixed_jet_mask(height, width, sigma=0.15, device=None):
    """Build a [height, width] weight mask forming a horizontal Gaussian band.

    Row ``i`` gets the weight ``exp(-(y_i - 0.5)**2 / (2 * sigma**2))`` where
    ``y`` runs linearly from 0 (top row) to 1 (bottom row). Every column holds
    the same profile, so the weight peaks at the vertical centre of the image.

    Args:
        height: Number of rows.
        width: Number of columns.
        sigma: Width of the band as a fraction of the image height; must be
            positive.
        device: Torch device for the result; defaults to the module-level
            ``DEVICE``.

    Returns:
        A float32 tensor of shape [height, width] with values in [0, 1].

    Raises:
        ValueError: if ``sigma`` is not positive (zero would divide by zero
            and fill the mask with nan/inf).
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    if device is None:
        device = DEVICE
    y = np.linspace(0, 1, height)
    attn = np.exp(-((y - 0.5) ** 2) / (2 * sigma ** 2))
    # Repeat the per-row profile across all columns.
    mask = np.tile(attn[:, None], (1, width))
    return torch.tensor(mask, dtype=torch.float32, device=device)

class RAFTWrapper(torch.nn.Module):
    """``nn.Module`` around RAFT that pads the inputs and crops the output.

    ``InputPadder`` pads both images before inference (in upstream RAFT, to a
    multiple of 8). The predicted flow is cropped back with ``unpad`` so that it
    lines up pixel-for-pixel with the original, unpadded images.
    """

    def __init__(self, args, iters=12):
        """Build the underlying RAFT network.

        Args:
            args: Namespace of RAFT options forwarded unchanged to ``RAFT``.
            iters: Number of refinement iterations passed to RAFT in
                ``forward`` (default 12).
        """
        super().__init__()
        self.raft = RAFT(args)
        self.iters = iters

    def forward(self, image1, image2):
        """Estimate the upsampled flow from ``image1`` to ``image2``.

        Args:
            image1: First image batch; ``image1.shape`` drives the padding.
            image2: Second image batch, presumably the same shape as
                ``image1`` (TODO confirm for callers).

        Returns:
            The ``flow_up`` output of RAFT, cropped to the spatial size of the
            unpadded inputs.
        """
        padder = InputPadder(image1.shape)
        image1, image2 = padder.pad(image1, image2)
        _, flow_up = self.raft(image1, image2, iters=self.iters, test_mode=True)
        # Undo the padding; without this the flow has the padded size and is
        # shifted relative to the input images.
        return padder.unpad(flow_up)

def save_flow_sequence_gif(flow_list, save_path, duration=150, loop=0):
    """Write a sequence of image arrays to an animated GIF.

    Args:
        flow_list: Non-empty iterable of arrays accepted by
            ``PIL.Image.fromarray`` (e.g. [H, W, 3] RGB). ``uint8`` frames are
            used as-is; frames of any other dtype are treated as values in
            [0, 1], clipped to that range and scaled to [0, 255].
        save_path: Destination file path.
        duration: Display time of each frame in milliseconds.
        loop: GIF loop count passed to Pillow (0 loops forever).

    Raises:
        ValueError: if ``flow_list`` contains no frames.
    """
    frames = list(flow_list)
    if not frames:
        raise ValueError("flow_list must contain at least one frame")

    images = []
    for frame in frames:
        frame = np.asarray(frame)
        if frame.dtype != np.uint8:
            # Clip before casting so out-of-range values saturate instead of
            # wrapping around; uint8 input (e.g. the jet-colormapped frames)
            # must not be scaled again or it would overflow.
            frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
        images.append(Image.fromarray(frame))
    images[0].save(
        save_path,
        save_all=True,
        append_images=images[1:],
        duration=duration,
        loop=loop,
    )

if __name__ == "__main__":
    args = Namespace(
        small=False,
        mixed_precision=False,
        alternate_corr=False,
        model='/Users/edasaruhan21/RAFT/models/raft-sintel.pth'
    )

    # Load the pretrained weights. Keys saved with a leading 'module.' are
    # stripped of that prefix only (not of 'module.' occurring elsewhere).
    model = RAFTWrapper(args).to(DEVICE)
    state_dict = torch.load(args.model, map_location=DEVICE)
    prefix = 'module.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    model.raft.load_state_dict(state_dict)
    model.eval()

    # Load all .png frames; sorted filename order defines the sequence order.
    test_folder = '/Users/edasaruhan21/22s/sample16'
    imgs = sorted(f for f in os.listdir(test_folder) if f.endswith('.png'))
    if len(imgs) < 2:
        raise RuntimeError("At least 2 sequential images required.")

    # For each consecutive pair: flow magnitude -> fixed band mask -> jet RGB.
    # NOTE: jet_attended_frames is collected but not written out here
    # (save_flow_sequence_gif could be used for that).
    jet_attended_frames = []
    for name1, name2 in zip(imgs, imgs[1:]):
        img1 = preprocess_image(os.path.join(test_folder, name1))
        img2 = preprocess_image(os.path.join(test_folder, name2))

        with torch.no_grad():
            flow = model(img1, img2)
            flow_mag_t = torch.norm(flow, dim=1).squeeze()  # [H, W]

            H, W = flow_mag_t.shape
            mask = create_fixed_jet_mask(H, W, sigma=0.15)
            attended = (flow_mag_t * mask).cpu().numpy()
            # Host-side copy: matplotlib and numpy cannot consume CUDA tensors.
            flow_mag = flow_mag_t.cpu().numpy()

        # Normalize to [0, 1] (epsilon avoids division by zero) and colormap.
        attended_norm = (attended - attended.min()) / (attended.max() - attended.min() + 1e-6)
        jet_rgb = plt.cm.jet(attended_norm)[..., :3]
        jet_rgb = (jet_rgb * 255).astype(np.uint8)

        jet_attended_frames.append(jet_rgb)

    # Plot the (unmasked) flow magnitude of the last frame pair only.
    plt.figure(figsize=(10, 5))
    im = plt.imshow(flow_mag, cmap='jet', vmin=0, vmax=14)
    plt.axis('off')
    plt.title("Optical Flow Magnitude (Jet)")

    cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
    cbar.set_label('Flow Magnitude', rotation=270, labelpad=15)

    plt.tight_layout()
    plt.savefig('/Users/edasaruhan21/RAFT/124.png', bbox_inches='tight', pad_inches=0)

    # Save one vertical column of the last flow magnitude map as text.
    x_index = 100
    if x_index < flow_mag.shape[1]:
        flow_column = flow_mag[:, x_index]
        np.savetxt('/Users/edasaruhan21/RAFT/new_melis_stage22_sample16_no_finetuned_100.txt', flow_column, fmt='%.6f')
    else:
        print(f"x={x_index} is out of bounds for image width {flow_mag.shape[1]}")
