In [None]:
import db_builder.db_handler as dbh
import numpy as np
from torchvision import transforms, datasets
from tqdm import tqdm
import logging
from getpass import getpass
from PIL import Image
import torchvision.models as models
import torch.nn as nn
import torch
import os
import cv2

In [None]:
# Raise SQLAlchemy's logger to WARNING so lower-level messages do not flood the cell output.
logging.getLogger('sqlalchemy').setLevel(logging.WARNING)
# Connection settings handed to dbh.DB_Conn. The password is prompted for
# interactively via getpass, so it is never stored in the notebook.
db_params = {
    'user': 'postgres',
    'password': getpass('Please enter DB pw'),  # enter your DB password
    'host': 'localhost',  # 'localhost' or IP address
    'port': '5432',  # '5432'
    'database': 'ttdatabase',  # database name (tensionTerminator)
}
# `toolcheck` is the DB handle used further down to look up a video's file path.
toolcheck = dbh.DB_Conn(db_params)
toolcheck.connect()
# NOTE(review): `engine` is not referenced again in the visible cells — confirm it is still needed.
engine = toolcheck.get_engine()

In [None]:
def trigger_crop(image, crop_box=(400, 450, 550, 550)):
    """Cut a rectangular region out of ``image``.

    Args:
        image: Image accepted by ``transforms.functional.crop`` (the caller in
            this notebook passes a PIL image).
        crop_box: ``(top, left, height, width)`` in pixels, in the order
            ``transforms.functional.crop`` expects. The default is the region
            that used to be hard-coded in this function.

    Returns:
        The cropped image as returned by ``transforms.functional.crop``.
    """
    # Unpack explicitly so a malformed box fails here with a clear error.
    top, left, height, width = crop_box
    return transforms.functional.crop(image, top, left, height, width)


def video_to_image_converter(source_path: str, output_path: str, crop=False, frame_step: int = 30):
    """Save every ``frame_step``-th frame of a video as a PNG file.

    Frames are written to ``output_path`` as ``<frame index>.png``, where the
    index is the zero-based position of the frame in the video.

    Args:
        source_path: Path of the video file to read.
        output_path: Directory for the PNG files; created if it is missing.
        crop: If true, each saved frame is first passed through ``trigger_crop``.
        frame_step: Keep one frame out of every ``frame_step`` frames. The
            default of 30 is the value that used to be hard-coded here.

    Raises:
        ValueError: If ``frame_step`` is smaller than 1.
        OSError: If OpenCV cannot open ``source_path``.
    """
    frame_step = int(frame_step)
    if frame_step < 1:
        raise ValueError(f"frame_step must be >= 1, got {frame_step}")

    vidcap = cv2.VideoCapture(source_path)
    try:
        # Fail loudly instead of silently producing an empty output folder.
        if not vidcap.isOpened():
            raise OSError(f"Could not open video: {source_path}")

        os.makedirs(output_path, exist_ok=True)

        # The PIL round-trip is only needed for cropping; build the converter once.
        to_pil = transforms.ToPILImage() if crop else None

        count = 0
        while True:
            # Read until the decoder reports the end of the stream rather than
            # trusting CAP_PROP_FRAME_COUNT, which may be missing or inaccurate.
            success, image = vidcap.read()
            if not success:
                break

            if count % frame_step == 0 and image is not None and image.size != 0:
                if crop:
                    image = np.asarray(trigger_crop(to_pil(image)))
                cv2.imwrite(os.path.join(output_path, f"{count}.png"), image)  # save frame as PNG file

            count += 1
    finally:
        # Always free the capture handle, even if a frame fails to process.
        # No HighGUI windows are opened here, so cv2.destroyAllWindows() is not
        # called (it can raise in headless OpenCV builds).
        vidcap.release()
    
def checkImage(path: str, transfer_model, orig_set, transforms_wt):
    """Classify one image file and report the winning class with its confidence.

    Args:
        path: Path of the image file to classify.
        transfer_model: Model called on a single-image batch; its output is
            treated as per-class logits of shape ``(1, num_classes)``.
        orig_set: Object whose ``classes`` sequence maps a class index to its
            label.
        transforms_wt: Callable that turns the PIL image into a tensor
            (without the batch dimension).

    Returns:
        ``(predicted_class, confidence)``: the class label and the softmax
        probability of that class expressed in percent.
    """
    # Close the file handle deterministically, and force three channels so
    # grayscale/RGBA files still fit a 3-channel preprocessing pipeline.
    with Image.open(path) as img:
        img_tensor = transforms_wt(img.convert("RGB")).unsqueeze(0)

    # Run on whatever device the model lives on instead of assuming CUDA.
    parameters = getattr(transfer_model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        prediction = transfer_model(img_tensor.to(device))
        predicted_probabilities = torch.softmax(prediction, dim=1)

    predicted_class_idx = torch.argmax(prediction, dim=1).item()
    predicted_class = orig_set.classes[predicted_class_idx]

    # Get the confidence score for the predicted class
    confidence = predicted_probabilities[0, predicted_class_idx].item() * 100  # Convert to percentage

    return predicted_class, confidence


# Preprocessing applied to every image before it reaches the model:
# resize, center-crop to 224x224, convert to a tensor, then normalize per channel.
# NOTE(review): the mean/std values look like the standard ImageNet statistics —
# confirm they match what the model was trained with.
transforms_wt = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])

In [None]:
# Scratch folder that receives the extracted frames.
directory = "tmp"
if not os.path.exists(directory):
    os.makedirs(directory)
    
# Look up the video file for loop id 146 in the database and dump its frames,
# cropped (third argument True), into the scratch folder.
video_source = toolcheck.get_filepath_by_loop_id(146)
video_to_image_converter(video_source, directory, True)


In [None]:
# Image-folder dataset; in the visible cells it is only used for its `classes`
# list, which maps the model's output indices to class names.
# NOTE(review): absolute, machine-specific Windows path — consider a configurable data directory.
orig_set = datasets.ImageFolder(
    root='C:\\Users\\Pirmin.000\\PycharmProjects\\IGP\\data\\tool_finder\\10_11_2023_21_05_33',
    transform=transforms_wt
)

In [None]:
# Show the class names in index order.
orig_set.classes

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-152 backbone whose final layer is replaced by a small MLP head with
# two outputs; the layer sizes must match the saved state dict loaded below.
transfer_model = models.resnet152()
transfer_model.fc = nn.Sequential(
    nn.Linear(transfer_model.fc.in_features, 2048),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(1024, 500),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(500, 2)
)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
# NOTE(review): absolute, machine-specific Windows path — consider a configurable model directory.
transfer_model_state_dict = torch.load(
    "C:\\Users\\Pirmin.000\\PycharmProjects\\IGP\\models\\tool_finder\\10_11_2023_21_05_33\\model.pt",
    map_location=device,
)
transfer_model.load_state_dict(transfer_model_state_dict)
transfer_model.to(device)
# eval() disables dropout so predictions are deterministic.
transfer_model.eval()

In [None]:
dataset = []  # predicted class label of every classified frame
conf = []  # confidence (percent) of every prediction
directory = "tmp"

for filename in tqdm(os.listdir(directory)):
    f = os.path.join(directory, filename)
    predicted_class, confidence = checkImage(f, transfer_model, orig_set, transforms_wt)
    conf.append(confidence)
    dataset.append(predicted_class)

    # Each frame is only needed once, so delete it right after classification.
    os.remove(f)
os.rmdir(directory)

# An empty frame folder would otherwise surface as an opaque ZeroDivisionError.
if not conf:
    raise RuntimeError(f"No frames were found in '{directory}'; run the frame extraction cell first.")
dataset_confidence = sum(conf) / len(conf)

In [None]:
# Mean prediction confidence (in percent) over all classified frames.
dataset_confidence

In [None]:
from collections import Counter

In [None]:
# Count how many frames were assigned to each predicted class.
test = Counter(dataset)

In [None]:
# Show the per-class frame counts.
test

In [None]:
# Number of frames predicted as the first class in orig_set.classes.
test[orig_set.classes[0]]

In [None]:
# Number of frames predicted as the second class in orig_set.classes.
test[orig_set.classes[1]]

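<!-- NOTE: a Markdown cell cannot interpolate Python variables, so `src=video_source` is taken literally and no video will load. To preview the clip, display it from a code cell instead, e.g. `IPython.display.Video(video_source, embed=True)`. -->
<video controls src=video_source />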

In [None]:
# File path of the analysed video, as returned by the database lookup.
video_source