In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import decode_image, decode_jpeg
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
import numpy as np
import matplotlib.pyplot as plt
from torch import Tensor
import random
import time
import cv2

In [2]:
class MobileNetV2ConvLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> optional ReLU6 building block.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Square kernel size; padding is ``(kernel_size - 1) // 2``.
        stride: Convolution stride.
        groups: Number of convolution groups.
        activation: When True a ReLU6 follows the batch norm; when False the
            layer ends at the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, groups=1, activation=True):
        super().__init__()
        padding = (kernel_size - 1) // 2
        # No conv bias: the batch norm that follows supplies the affine shift.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False)
        # PyTorch's `momentum` is the weight given to the NEW batch statistic
        # (running = (1 - momentum) * running + momentum * batch). The TF-style
        # decay of 0.997 therefore corresponds to momentum = 1 - 0.997 = 0.003.
        # This value is not part of the state_dict, so existing checkpoints
        # still load and inference results are unchanged.
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.003)
        self.activation = nn.ReLU6() if activation else None

    def forward(self, x):
        """Apply convolution, batch norm and (if configured) ReLU6 to ``x``."""
        x = self.bn(self.conv(x))
        # Explicit None check instead of relying on nn.Module truthiness.
        if self.activation is not None:
            x = self.activation(x)
        return x

class MobileNetV2InvertedResidual(nn.Module):
    """Inverted-residual bottleneck: 1x1 expand -> 3x3 grouped conv -> 1x1 reduce.

    The hidden width is ``in_channels * expand_ratio``. The 3x3 convolution
    uses one group per hidden channel and carries the block's stride. The
    final 1x1 convolution has no activation. The input is added to the result
    only when both tensors have exactly the same shape.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        super().__init__()
        expanded = in_channels * expand_ratio

        # Pointwise expansion to the wider hidden representation.
        self.expand_1x1 = MobileNetV2ConvLayer(in_channels, expanded, 1, activation=True)
        # Per-channel 3x3 convolution; the only place the stride is applied.
        self.conv_3x3 = MobileNetV2ConvLayer(
            expanded, expanded, 3, stride=stride, groups=expanded, activation=True
        )
        # Pointwise projection back down, without an activation.
        self.reduce_1x1 = MobileNetV2ConvLayer(expanded, out_channels, 1, activation=False)

    def forward(self, x):
        """Run the bottleneck on ``x`` and add the skip path when shapes match."""
        out = self.reduce_1x1(self.conv_3x3(self.expand_1x1(x)))
        # The skip connection is only valid when the block preserved the shape.
        return out + x if out.shape == x.shape else out
        

class Backbone(nn.Module):
    """Small MobileNetV2-style feature extractor returning three feature maps.

    The stem (3 -> 32 -> 32 -> 16 channels) contains one stride-2 convolution.
    ``block1`` and ``block3`` are stride-2 inverted residuals, ``block2`` is
    stride-1. The returned maps have 16, 24 and 32 channels respectively.
    """

    def __init__(self):
        super().__init__()
        stem_layers = [
            MobileNetV2ConvLayer(3, 32, 3, stride=2),
            MobileNetV2ConvLayer(32, 32, 3, groups=32),
            MobileNetV2ConvLayer(32, 16, 1, activation=False),
        ]
        self.conv_stem = nn.Sequential(*stem_layers)

        self.block1 = MobileNetV2InvertedResidual(16, 24, stride=2, expand_ratio=6)
        self.block2 = MobileNetV2InvertedResidual(24, 24, stride=1, expand_ratio=6)
        self.block3 = MobileNetV2InvertedResidual(24, 32, stride=2, expand_ratio=6)

    def forward(self, x):
        """Return ``(stem_features, block2_features, block3_features)`` for ``x``."""
        stem_features = self.conv_stem(x)
        # block1's output is only an intermediate; block2's output is exposed.
        mid_features = self.block2(self.block1(stem_features))
        deep_features = self.block3(mid_features)
        return stem_features, mid_features, deep_features

class CombinedEncoder(nn.Module):
    """Image encoder producing four feature maps.

    The first map comes from a single stride-1 3x3 conv (3 -> 8 channels)
    followed by batch norm and ReLU applied directly to the image. The other
    three are the outputs of ``Backbone`` in the order it returns them.
    """

    def __init__(self):
        super().__init__()
        self.backbone = Backbone()

        self.image_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.batch_norm = nn.BatchNorm2d(8)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, image):
        """Return the 4-tuple ``(image_features, *backbone_features)``."""
        # Shallow branch: kernel 3 with padding 1 keeps the input's spatial size.
        shallow = self.activation(self.batch_norm(self.image_conv(image)))
        # Deep branch: the three maps produced by the backbone.
        stem_features, mid_features, deep_features = self.backbone(image)
        return shallow, stem_features, mid_features, deep_features

class FlowEncoderResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with a 1x1-projected shortcut.

    Args:
        input_channels: Channels of the incoming tensor.
        hidden_channels: Channels between the two 3x3 convolutions.
        output_channels: Channels of the returned tensor.

    The shortcut is a bias-free 1x1 convolution, so input and output channel
    counts may differ. Both 3x3 convolutions use padding 1.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels, output_channels, kernel_size=3, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(hidden_channels)
        self.batch_norm2 = nn.BatchNorm2d(output_channels)
        self.residual_conv = nn.Conv2d(input_channels, output_channels, kernel_size=1, bias=False)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + residual_conv(x))``."""
        # Project the input first so the shortcut matches the output channels.
        shortcut = self.residual_conv(x)

        hidden = self.activation(self.batch_norm1(self.conv1(x)))
        main_path = self.batch_norm2(self.conv2(hidden))

        return self.activation(main_path + shortcut)

class FlowEncoder(nn.Module):
    """Fuses two four-level feature pyramids into four joint feature maps.

    At each level the features of both inputs are concatenated along the
    channel axis. From the second level on, the 2x average-pooled output of
    the previous level is concatenated as well. The block input widths
    (8*2, 16*2+16, 24*2+16, 32*2+32) encode the expected channel counts.
    """

    def __init__(self):
        super().__init__()
        self.block1 = FlowEncoderResidualBlock(8 * 2, 32, 16)
        self.block2 = FlowEncoderResidualBlock(16 * 2 + 16, 32, 16)
        self.block3 = FlowEncoderResidualBlock(24 * 2 + 16, 64, 32)
        self.block4 = FlowEncoderResidualBlock(32 * 2 + 32, 128, 32)

        self.pooling = nn.AvgPool2d(2)

    def forward(self, encoder_A_output, encoder_B_output):
        """Return the four fused maps, ordered from first level to last.

        Args:
            encoder_A_output: 4-sequence of feature maps for the first input.
            encoder_B_output: 4-sequence of feature maps for the second input.
        """
        a1, a2, a3, a4 = encoder_A_output
        b1, b2, b3, b4 = encoder_B_output

        level1 = self.block1(torch.cat((a1, b1), dim=1))
        # Each later level also receives the pooled result of the level before.
        level2 = self.block2(torch.cat((a2, b2, self.pooling(level1)), dim=1))
        level3 = self.block3(torch.cat((a3, b3, self.pooling(level2)), dim=1))
        level4 = self.block4(torch.cat((a4, b4, self.pooling(level3)), dim=1))

        return level1, level2, level3, level4

class FlowRefiner(nn.Module):
    """Refines a 32-channel feature map using a half-resolution side branch.

    The input passes through a residual block, is average-pooled by 2,
    processed by a second residual block and bilinearly resized back to the
    pre-pooling size. That result is concatenated with the ORIGINAL input
    (not with the first block's output) and fused by a final residual block.
    """

    def __init__(self):
        super().__init__()

        self.consistency_residual1 = FlowEncoderResidualBlock(32, 64, 32)
        self.consistency_residual2 = FlowEncoderResidualBlock(32 + 32, 64, 32)

        self.internal_consistency_residual = FlowEncoderResidualBlock(32, 64, 32)

        self.pooling = nn.AvgPool2d(2)

    def forward(self, input_flow):
        """Return the refined map; spatial size and channel count match the input."""
        full_res = self.consistency_residual1(input_flow)

        half_res = self.internal_consistency_residual(self.pooling(full_res))

        # Resize to the exact pre-pooling size so odd dimensions still line up.
        restored = F.interpolate(
            half_res,
            size=full_res.shape[2:],
            mode='bilinear',
            align_corners=False,
        )

        fused_input = torch.cat((restored, input_flow), dim=1)
        return self.consistency_residual2(fused_input)

class UpscaleBlock(nn.Module):
    """Upsamples a coarse map by 2x and fuses it with a finer detail map.

    Args:
        base_channels: Channels of the coarse ``base`` input.
        details_channels: Channels of the finer ``details`` input.
        out_channels: Channels of the fused output; the residual block's
            hidden width is twice this value.
    """

    def __init__(self, base_channels, details_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

        fused_channels = base_channels + details_channels
        self.block = FlowEncoderResidualBlock(fused_channels, out_channels * 2, out_channels)

    def forward(self, base, details):
        """Return the fusion of bilinearly 2x-upsampled ``base`` with ``details``."""
        # After upsampling, `base` must match the spatial size of `details`
        # for the channel-wise concatenation to succeed.
        merged = torch.cat((self.upsample(base), details), dim=1)
        return self.block(merged)

class FlowDecoder(nn.Module):
    """Decodes four fused feature maps into a 2-channel output.

    Starting from the last map, each ``UpscaleBlock`` doubles the resolution
    and merges in the next earlier map. A final 1x1 convolution maps the
    16-channel result to 2 channels.
    """

    def __init__(self):
        super().__init__()
        self.upscale_block1 = UpscaleBlock(32, 32, 32)
        self.upscale_block2 = UpscaleBlock(32, 16, 32)
        self.upscale_block3 = UpscaleBlock(32, 16, 16)

        self.linear = nn.Conv2d(16, 2, kernel_size=1)

    def forward(self, flows):
        """Return the 2-channel map at the resolution of ``flows[0]``.

        Args:
            flows: 4-sequence of feature maps ordered from first level to last.
        """
        level1, level2, level3, level4 = flows

        # Walk from the last level back to the first, doubling size each step.
        decoded = self.upscale_block1(level4, level3)
        decoded = self.upscale_block2(decoded, level2)
        decoded = self.upscale_block3(decoded, level1)

        return self.linear(decoded)



class OpticalFlowNetwork(nn.Module):
    """End-to-end network: shared encoder -> flow encoder -> refiner -> decoder.

    Both frames go through the same ``CombinedEncoder`` instance, so its
    weights are shared. Only the last fused level is passed through the
    refiner before decoding.
    """

    def __init__(self):
        super().__init__()
        self.encoder = CombinedEncoder()
        self.flow_encoder = FlowEncoder()
        self.flow_decoder = FlowDecoder()
        self.flow_refiner = FlowRefiner()

    def forward(self, frameA, frameB):
        """Return the decoder's 2-channel output for the pair ``(frameA, frameB)``."""
        features_a = self.encoder(frameA)
        features_b = self.encoder(frameB)

        level1, level2, level3, level4 = self.flow_encoder(features_a, features_b)
        # Refinement is applied to the last level only.
        refined_level4 = self.flow_refiner(level4)

        return self.flow_decoder((level1, level2, level3, refined_level4))

In [18]:
# Pick the compute device, rebuild the network and restore its trained weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = OpticalFlowNetwork()
# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved from a GPU session still loads on a CPU-only machine.
model.load_state_dict(
    torch.load("micron-flow-beta-0.1.pth", map_location=device, weights_only=True)
)
# eval() makes the batch-norm layers use their stored running statistics.
model = model.to(device).eval()

In [19]:
# Live demo: grab webcam frames, run the network on consecutive frame pairs and
# show the colour-coded result in an OpenCV window until 'q' is pressed or the
# camera stops delivering frames.
image_size = (152, 240)             # (height, width) of the frames fed to the network
display_size = (960 * 2, 520 * 2)   # (width, height) of the preview window image
smoothness = 0.0                    # weight of the previous output image; 0.0 disables blending

cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise IOError("Cannot open webcam")

prev_frame = None
prev_output = None

try:
    with torch.no_grad():
        while True:
            time.sleep(0.01)
            ret, frame = cap.read()

            if not ret:
                break

            # Preprocess: BGR -> RGB, resize (cv2 takes (width, height)), scale to
            # [0, 1] and lay out as a (1, 3, H, W) float tensor on the device.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            resized_frame = cv2.resize(
                frame_rgb, (image_size[1], image_size[0]), interpolation=cv2.INTER_AREA
            )
            tensor_frame = torch.from_numpy(resized_frame).permute(2, 0, 1).float() / 255.0
            tensor_frame = tensor_frame.unsqueeze(0).to(device)  # Add batch dimension

            # Skip inference when there is no previous frame yet or the camera
            # returned an identical frame. Unlike a `continue`, this still falls
            # through to cv2.waitKey below, so the window keeps processing events
            # and 'q' keeps working.
            if prev_frame is not None and not torch.equal(tensor_frame, prev_frame):
                cur = time.perf_counter()
                flow = model(tensor_frame, prev_frame)
                if device.type == "cuda":
                    # CUDA kernels launch asynchronously; wait for them so the
                    # printed value is the inference time, not the launch time.
                    torch.cuda.synchronize()
                print(time.perf_counter() - cur)

                flow = flow[0]
                # Pin two pixels of the first flow channel to -1 and +1.
                # NOTE(review): presumably this keeps flow_to_image's
                # max-magnitude normalisation from amplifying near-zero flow
                # into full-saturation noise - confirm.
                flow[0, 0, 0] = -1
                flow[0, 0, 1] = 1
                flow_img = flow_to_image(flow).permute(1, 2, 0).cpu().numpy()

                if prev_output is not None:
                    # Blend the previous and current output images.
                    prev_output = prev_output * smoothness + flow_img * (1.0 - smoothness)
                else:
                    prev_output = flow_img

                flow_img = prev_output.astype(np.uint8)

                flow_display = cv2.resize(flow_img, display_size)
                # flow_to_image yields RGB while cv2.imshow interprets 3-channel
                # images as BGR; convert so the displayed colours are not swapped.
                flow_display = cv2.cvtColor(flow_display, cv2.COLOR_RGB2BGR)

                # Display result
                cv2.imshow('Optical Flow', flow_display)

            # tensor_frame is rebuilt from scratch every iteration, so keeping a
            # reference is enough; no clone is needed.
            prev_frame = tensor_frame

            # Exit on 'q' press
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    # Always free the camera and close the window, including when the loop is
    # left through an exception or a kernel interrupt.
    cap.release()
    cv2.destroyAllWindows()

0.02394556999206543
0.012676715850830078
0.012248992919921875
0.012222766876220703
0.012387275695800781
0.012253761291503906
0.012729406356811523
0.012178897857666016
0.012613534927368164
0.012687206268310547
0.014333486557006836
0.012476921081542969
0.011945962905883789
0.012471675872802734
0.012022733688354492
0.012076854705810547
0.011684656143188477
0.01169133186340332
0.014392614364624023
0.012199878692626953
0.012138605117797852
0.011540412902832031
0.01169729232788086
0.012212514877319336
0.011911392211914062
0.013893842697143555
0.013659000396728516
0.012888431549072266
0.011305093765258789
0.013258934020996094
0.011907577514648438
0.012003183364868164
0.012407541275024414
0.011663675308227539
0.011890649795532227
0.011796712875366211
0.011845111846923828
0.012089014053344727


0.014444351196289062
0.012444734573364258
0.018328189849853516
0.009999752044677734
0.010007619857788086
0.010994434356689453
0.01298975944519043
0.010986804962158203
0.0165402889251709
0.009997367858886719
0.012990236282348633
0.012992620468139648
0.009996891021728516
0.015995264053344727
0.015016555786132812
0.03302454948425293
0.017011404037475586
0.011003732681274414
0.010000228881835938
0.010406255722045898
0.02199864387512207
0.010998964309692383
0.011537790298461914
0.031011104583740234
0.016997575759887695
0.012998580932617188
0.011998653411865234
0.014997005462646484
0.015535116195678711
0.012993574142456055
0.010999679565429688
0.01100611686706543
0.009999990463256836
0.010999679565429688
0.012991189956665039
0.009994983673095703
0.016994953155517578
0.013997793197631836
0.011998176574707031
0.0110015869140625
0.011013984680175781
0.0319972038269043
0.019998550415039062
0.0149993896484375
0.011995553970336914
0.009999275207519531
0.01598072052001953
0.013998031616210938
0.011