In [None]:
import os
import warnings
from collections import Counter

import cv2
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)

# Per-frame preprocessing: squash to a fixed square (no crop, aspect ratio is not
# preserved), convert to a float tensor, then normalize each RGB channel.
# The statistics below are the standard ImageNet values used by torchvision's
# pretrained models. Keep them in sync with whatever was used at training time.
img_size = 224
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_channel_mean, std=_channel_std),
    ]
)

Using device: cuda


In [None]:
def load_model(model_path, num_classes, strict=False):
    """Build a MobileNetV2 classifier and load fine-tuned weights from disk.

    The network is created with torchvision's ImageNet-1k weights. Its final linear
    layer (``classifier[1]``) is then replaced by a fresh head with ``num_classes``
    outputs, and the checkpoint at ``model_path`` is loaded on top.

    Args:
        model_path: Path to a file written by ``torch.save``. It is assumed to hold a
            plain ``state_dict`` for this architecture (TODO confirm against the
            training code).
        num_classes: Number of outputs for the replacement classification head.
        strict: Forwarded to ``load_state_dict``. Defaults to ``False`` to keep the
            original behaviour, but mismatched keys are now reported with a
            ``UserWarning`` instead of being silently ignored.

    Returns:
        The model, moved to the module-level ``device`` and switched to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` when ``strict=True`` and keys differ.
    """
    model = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
    model.classifier[1] = torch.nn.Linear(model.last_channel, num_classes)

    # weights_only=True restricts unpickling to tensors and primitive containers.
    # A tampered .pth file therefore cannot run arbitrary code on load. It also
    # avoids the FutureWarning recent PyTorch releases emit when the flag is unset.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)

    incompatible = model.load_state_dict(checkpoint, strict=strict)
    # With strict=False, mismatched keys are skipped without error. A wrong or
    # wrapped checkpoint (e.g. {"model_state_dict": ...}) would then leave the
    # ImageNet backbone plus a randomly initialised head in place, so surface it.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_path!r} does not fully match the model: "
            f"{len(incompatible.missing_keys)} missing key(s) "
            f"{incompatible.missing_keys[:5]}, "
            f"{len(incompatible.unexpected_keys)} unexpected key(s) "
            f"{incompatible.unexpected_keys[:5]}",
            stacklevel=2,
        )

    return model.to(device).eval()

In [None]:
def predict_video(video_path, model, class_names, frame_rate=5):
    """Classify a video by majority vote over sampled per-frame predictions.

    Every ``frame_rate``-th frame, starting with the first, is processed as follows:
    it is converted from OpenCV's BGR channel order to RGB, preprocessed with the
    module-level ``transform``, and classified by ``model``. The label predicted
    most often wins.

    Args:
        video_path: Anything ``cv2.VideoCapture`` can open (typically a file path).
        model: Classifier called on a batch of one image. It is assumed to return
            logits of shape (1, len(class_names)) (TODO confirm for other models).
        class_names: Sequence mapping predicted class index to label.
        frame_rate: Sampling stride in frames. Despite the name, this is not a
            frames-per-second value. Must be >= 1.

    Returns:
        The most frequently predicted label. Ties go to the label predicted first.

    Raises:
        ValueError: If ``frame_rate`` < 1, the video cannot be opened, or no frame
            could be read from it.
    """
    if frame_rate < 1:
        raise ValueError(f"frame_rate must be >= 1, got {frame_rate}")

    cap = cv2.VideoCapture(video_path)
    # VideoCapture does not raise on a bad path; it just returns an unopened handle.
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    votes = Counter()
    frame_count = 0
    try:
        with torch.no_grad():
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % frame_rate == 0:
                    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    input_tensor = transform(image).unsqueeze(0).to(device)
                    output = model(input_tensor)
                    _, predicted = torch.max(output, 1)
                    votes[class_names[predicted.item()]] += 1

                frame_count += 1
    finally:
        # Release the capture even if preprocessing or inference raises.
        cap.release()

    if not votes:
        raise ValueError(f"No frames could be read from video: {video_path}")

    # Counter.most_common orders equal counts by first insertion, so ties are
    # deterministic. max(set(...)) depended on set iteration order, which for
    # strings changes between interpreter runs due to hash randomization.
    return votes.most_common(1)[0][0]

In [None]:
if __name__ == "__main__":
    # The original absolute paths remain the defaults, so behaviour on the author's
    # machine is unchanged. Set these environment variables to run elsewhere.
    project_dir = os.environ.get("PROJECT_X_DIR", "/home/mostafabakr/Desktop/Project X")
    video_model_path = os.environ.get(
        "ASL_VIDEO_MODEL_PATH",
        os.path.join(project_dir, "models", "asl_video_model.pth"),
    )
    video_path = os.environ.get(
        "ASL_TEST_VIDEO_PATH",
        os.path.join(project_dir, "Test_img", "Screencast from 11-29-2024 04:04:33 AM.webm"),
    )

    # predict_video indexes this list by the model's predicted class index. The
    # order must therefore match the head's output order (presumably the training order).
    video_class_names = ['J', 'Z']
    num_video_classes = len(video_class_names)

    # cv2.VideoCapture silently fails on a missing file, so fail fast with a clear error.
    if not os.path.isfile(video_path):
        raise FileNotFoundError(f"Video not found: {video_path}")

    video_model = load_model(video_model_path, num_video_classes)

    video_prediction = predict_video(video_path, video_model, video_class_names)
    print(f"Prediction for the video: {video_prediction}")

  checkpoint = torch.load(model_path, map_location=device)


Prediction for the video: J
