In [1]:
import sys
import os
import numpy as np

# Locate the zfa project directory and make it importable.
# Set the ZFA_PATH environment variable to point at a different checkout;
# when it is unset, the original developer path below is used as the fallback.
zfa_path = os.environ.get('ZFA_PATH', '/Users/alyn/Desktop/zebrafish_agent/')
# Membership check keeps sys.path free of duplicates when the cell is re-run.
if zfa_path not in sys.path:
    sys.path.append(zfa_path)



In [2]:
class Args:
    """Plain attribute container holding the options for a small test run.

    Every option is set to a fixed value at construction time; the constructor
    takes no arguments. How each option is consumed is decided by the code
    that receives the instance, which is not shown here.
    """

    def __init__(self):
        # Populate the instance dictionary in one pass. Keyword order is kept
        # so the attributes are inserted in the same order as before.
        vars(self).update(
            size=64,
            gpus=0,
            checkpoint_path='checkpoints',
            lr=1e-3,
            wdecay=1e-5,
            epsilon=1e-8,
            num_epochs=1,
            batch_size=1,
            val_freq=1,
            clip=1.0,
            mixed_precision=False,
            add_noise=True,
            save_checkpoint=False,
            name='test1',
        )

In [3]:
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from zfa.model_training.augmentor import FlowAugmentorOpticFlow, Transform

class OpticFlowDataset(Dataset):
    """Dataset yielding ``(frame1, frame2, flow)`` tensors built from random data.

    Each item is a pair of random RGB images plus a random two-channel flow
    field, all of spatial size ``img_size`` before any augmentation or
    transform is applied. Nothing is read from disk.

    Args:
        split: Accepted but not used by this class.
        aug_params: Optional keyword arguments forwarded to
            ``FlowAugmentorOpticFlow``. When falsy, no augmentor is created.
        fake_data: When true, placeholder sample lists of length
            ``num_samples`` are generated. When false, ``self.match_files()``
            is called.
        num_samples: Number of fake samples, i.e. the dataset length.
        img_size: ``(height, width)`` of the generated images and flow field.
        transform: Optional callable applied to ``(frame1, frame2, flow)``
            after the augmentor. ``None`` (the default) disables it.
    """

    def __init__(self,
                 split='train',
                 aug_params=None,
                 fake_data=True,  # Switch to generated placeholder samples
                 num_samples=10,  # Number of fake samples to generate
                 img_size=(64, 64),  # (height, width) of the generated images
                 transform=None):  # Optional callable applied in __getitem__
        self.augmentor = None
        self.transform = transform
        self.fake_data = fake_data
        self.num_samples = num_samples
        self.img_size = img_size

        if aug_params:
            self.augmentor = FlowAugmentorOpticFlow(**aug_params)

        # Generate fake data or initialize normally
        if self.fake_data:
            self.generate_fake_data()
        else:
            self.frame_pairs = []
            self.flow_files = []
            # NOTE(review): match_files is not defined on this class, so this
            # branch raises AttributeError unless a subclass provides it.
            self.match_files()

    def generate_fake_data(self):
        """Fill the sample lists with ``num_samples`` placeholder entries."""
        self.frame_pairs = [("frame1", "frame2") for _ in range(self.num_samples)]
        self.flow_files = ["flow"] * self.num_samples

    def read_flo_file(self, fn):
        """Return a random float32 flow field of shape ``(h, w, 2)``.

        ``fn`` is ignored; no file is opened.
        """
        h, w = self.img_size
        return np.random.randn(h, w, 2).astype(np.float32)  # Random flow field

    def __len__(self):
        """Number of samples, taken from the length of ``flow_files``."""
        return len(self.flow_files)

    def __getitem__(self, idx):
        """Build one sample and return it as three float tensors.

        Returns:
            ``(frame1, frame2, flow)`` with the channel axis moved first:
            frames are ``[C, H, W]`` and flow is ``[2, H, W]``.
        """
        # Random RGB images are generated on every access, so repeated reads
        # of the same index give different data.
        frame1 = np.random.randint(0, 256, (*self.img_size, 3), dtype=np.uint8)
        frame2 = np.random.randint(0, 256, (*self.img_size, 3), dtype=np.uint8)
        flow = self.read_flo_file(self.flow_files[idx])

        # Convert to PIL images
        frame1 = Image.fromarray(frame1)
        frame2 = Image.fromarray(frame2)

        if self.augmentor:
            frame1, frame2, flow = self.augmentor(frame1, frame2, flow)

        if self.transform is not None:
            # Assumes the transform takes and returns (frame1, frame2, flow)
            # in the same form as the augmentor - TODO confirm against
            # zfa.model_training.augmentor.Transform.
            frame1, frame2, flow = self.transform(frame1, frame2, flow)

        # Convert to torch tensors
        frame1 = torch.from_numpy(np.array(frame1)).permute(2, 0, 1).float()  # [C, H, W]
        frame2 = torch.from_numpy(np.array(frame2)).permute(2, 0, 1).float()  # [C, H, W]
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()  # [2, H, W]

        return frame1, frame2, flow

In [4]:
# Create a DataLoader using the fake dataset
from torch.utils.data import DataLoader


# Augmentation settings for the training dataset; OpticFlowDataset forwards
# them as keyword arguments to FlowAugmentorOpticFlow.
aug_params = dict(
    crop_size=(32, 32),  # Crop window of 32x32
    min_scale=-0.2,
    max_scale=0.5,
    do_flip=True,
    size=64,
)

transform = Transform(size=64)

# Both datasets use generated data: the training set gets the augmentation
# settings, the validation set gets the transform.
# NOTE(review): this relies on OpticFlowDataset accepting a `transform` keyword.
dataset_train = OpticFlowDataset(
    fake_data=True,
    num_samples=5,
    img_size=(480, 853),
    aug_params=aug_params,
)
dataset_val = OpticFlowDataset(
    fake_data=True,
    num_samples=5,
    img_size=(480, 853),
    transform=transform,
)

# Batches of two; only the training loader shuffles.
dataloader_train = DataLoader(dataset_train, batch_size=2, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=2, shuffle=False)

# Walk each loader once and report the tensor shapes of every batch.
for batch_idx, (frame1, frame2, flow) in enumerate(dataloader_train):
    print(
        f"TRAIN Batch {batch_idx}:",
        f"  Frame 1 Shape: {frame1.shape}",
        f"  Frame 2 Shape: {frame2.shape}",
        f"  Flow Shape: {flow.shape}",
        sep="\n",
    )
for batch_idx, (frame1, frame2, flow) in enumerate(dataloader_val):
    print(
        f"VAL Batch {batch_idx}:",
        f"  Frame 1 Shape: {frame1.shape}",
        f"  Frame 2 Shape: {frame2.shape}",
        f"  Flow Shape: {flow.shape}",
        sep="\n",
    )

TRAIN Batch 0:
  Frame 1 Shape: torch.Size([2, 3, 64, 64])
  Frame 2 Shape: torch.Size([2, 3, 64, 64])
  Flow Shape: torch.Size([2, 2, 64, 64])
TRAIN Batch 1:
  Frame 1 Shape: torch.Size([2, 3, 64, 64])
  Frame 2 Shape: torch.Size([2, 3, 64, 64])
  Flow Shape: torch.Size([2, 2, 64, 64])
TRAIN Batch 2:
  Frame 1 Shape: torch.Size([1, 3, 64, 64])
  Frame 2 Shape: torch.Size([1, 3, 64, 64])
  Flow Shape: torch.Size([1, 2, 64, 64])
VAL Batch 0:
  Frame 1 Shape: torch.Size([2, 3, 64, 64])
  Frame 2 Shape: torch.Size([2, 3, 64, 64])
  Flow Shape: torch.Size([2, 2, 64, 64])
VAL Batch 1:
  Frame 1 Shape: torch.Size([2, 3, 64, 64])
  Frame 2 Shape: torch.Size([2, 3, 64, 64])
  Flow Shape: torch.Size([2, 2, 64, 64])
VAL Batch 2:
  Frame 1 Shape: torch.Size([1, 3, 64, 64])
  Frame 2 Shape: torch.Size([1, 3, 64, 64])
  Flow Shape: torch.Size([1, 2, 64, 64])


In [5]:
from zfa.models.optic_flow_architetures import Net_3_layers
from zfa.model_training.evaluate import validate_model
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network, then move it to the selected device.
model = Net_3_layers(64)
model = model.to(device)

# Run validation over the validation loader and show the returned results.
results = validate_model(model, dataloader_val, device)
print(f"Validation Results: {results}")

flow_pr.size(): torch.Size([2, 2, 64, 64])
flow_gt.size(): torch.Size([2, 2, 64, 64])
flow_pr.size(): torch.Size([2, 2, 64, 64])
flow_gt.size(): torch.Size([2, 2, 64, 64])
flow_pr.size(): torch.Size([1, 2, 64, 64])
flow_gt.size(): torch.Size([1, 2, 64, 64])
Validation EPE: 0.323702
Validation Results: {'epe': 0.3237015}
