In [1]:
# This one is a bit unusual: the goal is a wrapper that can take in video from Twitch, YouTube, or a monitor of choice, identify (ID) what it shows, and process it. Let's try Selenium first.
from selenium_helpers import buildHeadless, getBy
from PIL import Image
import io
import numpy as np
import time
from typing import List, Dict, Optional
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights
import json
import torch
import numpy as np
import cv2

In [2]:
class TwitchStreamer():
    """
    Given a twitch streamer, navigate to their channel and start watching the stream.

    The instance owns a headless browser (built by ``buildHeadless``). Call
    ``close()`` -- or use the instance as a context manager -- to quit the browser
    when finished; otherwise the browser process is left running.
    """
    def __init__(self, username, startup_wait=3):
        """
        Open https://www.twitch.tv/<username> in a headless browser.

        Args:
            username: Twitch channel name appended to the twitch.tv URL.
            startup_wait: seconds to sleep after locating the <video> element,
                giving the mute overlay time to unhide (defaults to the original 3s).

        If navigation or the element lookup raises, the browser is quit before
        the exception propagates so no process is leaked.
        """
        # build a driver
        self.driver = buildHeadless()
        try:
            # navigate to the page
            self.driver.get(f'https://www.twitch.tv/{username}')
            # presumably getBy(driver, 'tag', 'video') finds the <video> element by
            # tag name -- TODO confirm against selenium_helpers
            self.element = getBy(self.driver, 'tag', 'video')
            # wait a bit for the mute thing to unhide
            time.sleep(startup_wait)
        except Exception:
            self.close()
            raise

    def get_img(self):
        """
        Return a PIL image of the video element, cropped out of a screenshot of
        the current browser window.
        """
        # re-read position/size on every call so layout changes are picked up
        location = self.element.location
        size = self.element.size

        data = self.driver.get_screenshot_as_png()
        im = Image.open(io.BytesIO(data))

        left = location['x']
        top = location['y']
        right = left + size['width']
        bottom = top + size['height']

        # PIL crop box is (left, upper, right, lower) in integer pixels
        return im.crop((int(left), int(top), int(right), int(bottom)))

    def close(self):
        """Quit the browser driver. Calling it again after the first time is a no-op."""
        driver = getattr(self, 'driver', None)
        if driver is not None:
            driver.quit()
            self.driver = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # always release the browser; never suppress the exception
        self.close()
        return False

In [3]:
class TwitchPredict():
    """
    Run a fine-tuned Faster R-CNN detector on frames grabbed from a Twitch stream.

    Construction loads the class list from a JSON params file, builds the model,
    loads trained weights from a checkpoint, and opens a ``TwitchStreamer`` for the
    given channel. ``predict()`` grabs one frame and returns its detections in a
    Label Studio-style prediction structure.
    """

    def __init__(self, username, params_path='run_params.json', model_path='model.pth'):
        """
        Load the detector and start watching ``username``'s stream.

        Args:
            username: Twitch channel to watch (passed to ``TwitchStreamer``).
            params_path: JSON file containing a ``'classes'`` list; the list length
                sets the number of detector outputs. Presumably index 0 is the
                background class, as torchvision detectors expect -- TODO confirm.
            model_path: checkpoint containing a ``'model_state_dict'`` entry.
        """
        # let's read in the JSON
        with open(params_path, 'r') as json_file:
            self.json = json.load(json_file)

        self.classes = self.json['classes']
        num_classes = len(self.classes)

        # init the device
        self.DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # Faster R-CNN ResNet-50 FPN with pretrained weights; the box predictor head
        # is replaced to emit num_classes classes, and the full state dict is then
        # overwritten from the checkpoint below.
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # NOTE(security): torch.load is pickle-based -- only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        model.to(self.DEVICE)

        # and assign it for use later
        self.model = model

        # let's also make our streamer object
        self.streamer = TwitchStreamer(username)

    def predict(self, detection_threshold=0.9):
        """
        Grab the current stream frame and run the detector on it.

        Args:
            detection_threshold: minimum detector score for a box to be reported.

        Returns:
            A one-element list ``[{'result': [...], 'score': 0.6}]``. Each entry of
            ``'result'`` is a ``rectanglelabels`` region whose x/y/width/height are
            percentages of the frame width/height and whose ``'score'`` is the
            detector's confidence as a float.
        """
        print('predicting')

        # PIL image -> RGB uint8 array of shape (H, W, 3); convert('RGB') drops the
        # alpha channel a PNG screenshot may carry.
        frame = np.array(self.streamer.get_img().convert('RGB'))
        img_height, img_width = frame.shape[:2]

        # HWC in [0, 255] -> CHW float32 in [0, 1], plus a leading batch dimension,
        # placed on the model's device (works on CPU-only machines too).
        chw = np.transpose(frame.astype(np.float32) / 255.0, (2, 0, 1))
        batch = torch.tensor(chw, dtype=torch.float).unsqueeze(0).to(self.DEVICE)

        with torch.no_grad():
            outputs = self.model(batch)
        # one image in -> one detection dict out; move it to the CPU for numpy
        detections = {k: v.to('cpu') for k, v in outputs[0].items()}

        boxes = detections['boxes'].numpy()
        scores = detections['scores'].numpy()
        labels = detections['labels'].numpy()

        # Apply one mask to boxes, scores and labels so they stay aligned no matter
        # how the detections are ordered.
        keep = scores >= detection_threshold
        # box corners are truncated to whole pixels before conversion to percentages
        kept_boxes = boxes[keep].astype(np.int32)

        results = []
        for box, score, label in zip(kept_boxes, scores[keep], labels[keep]):
            x, y, xmax, ymax = box[:4]
            results.append({
                'from_name': "label",
                'to_name': "image",
                'type': 'rectanglelabels',
                'value': {
                    'rectanglelabels': [self.classes[label]],
                    'x': float(x) / img_width * 100,
                    'y': float(y) / img_height * 100,
                    'width': (float(xmax) - float(x)) / img_width * 100,
                    'height': (float(ymax) - float(y)) / img_height * 100,
                },
                # keep the real confidence; round() would collapse it to 0 or 1
                'score': float(score),
            })
        # the prediction-level score is a fixed value, not derived from detections
        return [{
            'result': results,
            'score': .6
        }]

In [4]:
# Load the detector and attach it to the live stream of the channel below.
predictor = TwitchPredict(username='n3zmodgod')

In [5]:
%%time
# check things out: time 100 back-to-back predict() calls (screenshot + inference).
# Return values are discarded -- this cell only measures wall-clock time.
# Each predict() call prints 'predicting', hence one output line per iteration.
for i in range(100):
    predictor.predict()

predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting
predicting