# In [5]:
import cv2
import numpy as np
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from sklearn.preprocessing import LabelEncoder

# Step 1: Extract Optical Flow from Videos
def extract_optical_flow(video_path, save_dir):
    """Compute dense optical flow for a video and save it as colour images.

    Every pair of consecutive frames is passed to Farneback's dense
    optical-flow routine.  The resulting flow field is rendered as an image
    in which hue encodes the flow direction and brightness encodes the flow
    magnitude (min-max normalised per frame), then written to ``save_dir``
    as ``flow_<index>.png`` with indices starting at 0.

    Args:
        video_path: Path of the video file to read.
        save_dir: Directory that receives the PNG files; created if missing.

    Returns:
        The number of flow images written.

    Raises:
        IOError: If no frame can be read from ``video_path`` or a flow image
            cannot be written.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        ret, prev_frame = cap.read()
        if not ret or prev_frame is None:
            # Fail with a clear message instead of passing None to cvtColor.
            raise IOError(f"Could not read any frame from video: {video_path}")
        prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)

        # Only create the output directory once the video is known to be readable.
        os.makedirs(save_dir, exist_ok=True)

        # The HSV buffer is reused for every frame: saturation stays at 255,
        # while the hue and value channels are overwritten on each iteration.
        hsv = np.zeros_like(prev_frame)
        hsv[..., 1] = 255

        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            # Positional arguments after the flow placeholder are: pyr_scale,
            # levels, winsize, iterations, poly_n, poly_sigma, flags.
            flow = cv2.calcOpticalFlowFarneback(prev_gray, gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)

            # Convert flow to RGB representation
            mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            # Angle in radians -> degrees, halved to fit OpenCV's 8-bit hue range.
            hsv[..., 0] = ang * 180 / np.pi / 2
            hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
            flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

            # Save the optical flow frame
            out_path = os.path.join(save_dir, f"flow_{frame_idx}.png")
            if not cv2.imwrite(out_path, flow_rgb):
                raise IOError(f"Could not write optical flow image: {out_path}")

            # cvtColor allocates a fresh array each iteration, so no copy is needed.
            prev_gray = gray
            frame_idx += 1

        return frame_idx
    finally:
        # Release the capture even when reading or writing fails.
        cap.release()

# Step 2: Process All Videos in a Folder
def process_videos(video_folder, output_folder):
    """Extract optical flow for every ``.mp4`` file found under a folder.

    The directory structure below ``video_folder`` is mirrored under
    ``output_folder``: a video at ``<video_folder>/<sub>/<name>.mp4`` has its
    flow frames written to ``<output_folder>/<sub>/<name>/``.  Videos placed
    directly in ``video_folder`` end up in ``<output_folder>/<name>/``.
    Mirroring sub-folders lets a ``<label>/<video>.mp4`` input layout produce
    a ``<label>/<video>/<frame>`` output layout.

    Args:
        video_folder: Root directory containing the videos.
        output_folder: Root directory for the extracted flow frames; created
            if missing.  It may live inside ``video_folder``; it is never
            scanned for videos.

    Raises:
        FileNotFoundError: If ``video_folder`` is not an existing directory.
    """
    if not os.path.isdir(video_folder):
        raise FileNotFoundError(f"Video folder does not exist: {video_folder}")

    os.makedirs(output_folder, exist_ok=True)
    output_root = os.path.abspath(output_folder)

    for dirpath, dirnames, filenames in os.walk(video_folder):
        # Prune in place so the walk never descends into the output tree
        # (relevant when it is nested inside the input), and sort so the
        # processing order is deterministic.
        dirnames[:] = sorted(
            name
            for name in dirnames
            if os.path.abspath(os.path.join(dirpath, name)) != output_root
        )

        rel_dir = os.path.relpath(dirpath, video_folder)
        if rel_dir == os.curdir:
            target_dir = output_folder
        else:
            target_dir = os.path.join(output_folder, rel_dir)

        for video_name in sorted(filenames):
            if not video_name.lower().endswith(".mp4"):
                continue
            # Create a subfolder for each video's optical flow frames
            video_output_folder = os.path.join(target_dir, os.path.splitext(video_name)[0])
            extract_optical_flow(os.path.join(dirpath, video_name), video_output_folder)

# Step 3: Define the Dataset Class
class OpticalFlowDataset(Dataset):
    """Dataset of optical-flow images laid out as ``<data_dir>/<label>/<video>/<image>``.

    The name of each first-level directory is used as the label of every
    image found two levels below it.  Entries are collected in sorted order
    so that indices are reproducible across runs and platforms.

    Args:
        data_dir: Root directory of the dataset.
        transform: Optional callable applied to each RGB image array.
        label_encoder: Optional object whose ``transform`` method converts
            the list of label names into encoded labels.  When omitted, the
            labels stay as directory-name strings.
    """

    # Only files with these extensions are treated as frames; other entries
    # in a video folder would otherwise make cv2.imread return None.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, data_dir, transform=None, label_encoder=None):
        self.data_dir = data_dir
        self.transform = transform
        self.label_encoder = label_encoder
        self.image_paths = []
        self.labels = []

        # Load image paths and labels
        for label in sorted(os.listdir(data_dir)):
            label_dir = os.path.join(data_dir, label)
            if not os.path.isdir(label_dir):
                continue
            for video_folder in sorted(os.listdir(label_dir)):
                video_folder_path = os.path.join(label_dir, video_folder)
                if not os.path.isdir(video_folder_path):
                    continue
                for image_name in sorted(os.listdir(video_folder_path)):
                    if not image_name.lower().endswith(self.IMAGE_EXTENSIONS):
                        continue
                    self.image_paths.append(os.path.join(video_folder_path, image_name))
                    self.labels.append(label)

        # Encode labels if a label encoder is provided
        if self.label_encoder is not None:
            self.labels = self.label_encoder.transform(self.labels)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return it with its label.

        Returns:
            A ``(image, label)`` pair.  The image is converted from BGR to RGB
            and then passed through ``transform`` when one was given.

        Raises:
            IOError: If the image file cannot be decoded.
        """
        path = self.image_paths[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label

# Step 4: Train the Model
def train_model(dataloader, model, criterion, optimizer, num_epochs, device):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    For each batch the inputs and labels are moved to ``device``, the loss is
    computed with ``criterion`` and one optimizer step is taken.  The mean
    batch loss is printed after every epoch.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        model: Model to optimise; it is updated in place.
        criterion: Callable computing the loss from ``(outputs, labels)``.
        optimizer: Optimizer holding the model parameters.
        num_epochs: Number of passes over ``dataloader``.
        device: Device the batches are moved to before the forward pass.
    """
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches actually seen: this avoids dividing by
        # zero on an empty loader and works for iterables without __len__.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    print("Training complete.")



# Main Script
if __name__ == "__main__":
    # Step 1: Extract Optical Flow from Videos
    video_folder = "strike_videos"  # Folder containing your strike videos
    output_folder = "optical_flow_frames"  # Folder to save optical flow frames
    process_videos(video_folder, output_folder)

    # Step 2: Prepare Dataset
    # ToTensor runs first so that Resize operates on a tensor rather than on
    # the raw array returned by the dataset.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
    ])

    # Encode labels
    # These names must match the first-level directory names under
    # ``output_folder``, because the dataset uses those names as labels.
    labels = ["fastball", "curveball", "slider"]  # Replace with your actual labels
    label_encoder = LabelEncoder()
    label_encoder.fit(labels)

    dataset = OpticalFlowDataset(data_dir=output_folder, transform=transform, label_encoder=label_encoder)
    if len(dataset) == 0:
        # Stop early with an actionable message instead of failing later
        # inside the data loader or the training loop.
        raise SystemExit(
            f"No optical flow frames found under '{output_folder}'. "
            "Expected layout: <data_dir>/<label>/<video>/<frame image>."
        )
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Step 3: Define Model
    num_classes = len(label_encoder.classes_)
    model = resnet18(pretrained=True)
    # Replace the final layer so the output size matches the number of labels.
    model.fc = torch.nn.Linear(model.fc.in_features, out_features=num_classes)

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Step 4: Train the Model
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    num_epochs = 10

    train_model(dataloader, model, criterion, optimizer, num_epochs, device)