In [1]:
import matplotlib.pyplot as plt
import torch
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import glob
import cv2 as cv
import numpy as np
from pprint import pprint
import pickle

In [2]:
# Number of full passes over the training set.
NUM_EPOCHS = 35

# Location of the EndoVis robotic-segmentation training videos, relative to the
# notebook's working directory. The path is assembled from individual components
# so that the platform's own separator is used (the previous single string with
# hard-coded backslashes only resolved on Windows; on Windows the result is unchanged).
DATASET_PATH = os.path.join(
    os.getcwd(),
    "..", "..",
    "datasets", "unzipped", "endovis",
    "Segmentation_Robotic_Training", "Training",
)

DATASET_PATH

'c:\\Users\\thegr\\python\\school\\cv\\ecen-644-notebook\\final_project\\..\\..\\datasets\\unzipped\\endovis\\Segmentation_Robotic_Training\\Training'

In [3]:
# Discover the raw videos and their segmentation videos below DATASET_PATH.
# - recursive=True is required for "**" to span directory levels; without it
#   glob treats "**" like a plain "*".
# - The results are sorted because glob returns files in arbitrary order, and the
#   two lists are later paired frame-by-frame purely by position.
# - os.path.join keeps the pattern independent of the platform's path separator.
input_videos = sorted(
    glob.glob(os.path.join(DATASET_PATH, "**", "*Video*.avi"), recursive=True)
)
output_videos = sorted(
    glob.glob(os.path.join(DATASET_PATH, "**", "*Segmentation*.avi"), recursive=True)
)
pprint(input_videos)
pprint(output_videos)

[]
[]


In [4]:
# Load every frame of the input and segmentation videos into memory, using two
# pickle files in the working directory as a cache so the videos only need to be
# decoded once.
input_frames: list[np.ndarray] = []
output_frames: list[np.ndarray] = []
input_frames_raw_pkl_path = os.path.join(os.getcwd(), "input_frames_raw.pkl")
output_frames_raw_pkl_path = os.path.join(os.getcwd(), "output_frames_raw.pkl")

if os.path.exists(input_frames_raw_pkl_path) and os.path.exists(output_frames_raw_pkl_path):
    # NOTE(security): pickle.load can execute arbitrary code. Only load cache
    # files that were produced by this notebook.
    with open(input_frames_raw_pkl_path, "rb") as fp:
        input_frames = pickle.load(fp)

    with open(output_frames_raw_pkl_path, "rb") as fp:
        output_frames = pickle.load(fp)
else:
    # TODO: support vid w/ two instruments...
    for vids, container in zip((input_videos, output_videos), (input_frames, output_frames)):
        for vid in vids:
            print(f"Processing {vid}")
            cap = cv.VideoCapture(vid)
            frames = 0
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret or frame is None:
                        print(f"End of vid, total frames: {frames}")
                        break

                    # For now, just append the raw frame. Additional processing can be carried out later.
                    container.append(frame)
                    # Count only frames that were actually decoded, so the
                    # reported total is exact.
                    frames += 1
            finally:
                # Always free the decoder / file handle, even if reading fails.
                cap.release()

    # Only write the cache when frames were actually decoded. Caching empty lists
    # (e.g. because no videos were found) would make every later run load the
    # empty cache and skip decoding.
    if input_frames and output_frames:
        with open(input_frames_raw_pkl_path, "wb") as fp:
            pickle.dump(input_frames, fp)

        with open(output_frames_raw_pkl_path, "wb") as fp:
            pickle.dump(output_frames, fp)
    else:
        print("No frames were decoded; cache files were not written.")

In [9]:
# Integer factor by which frame height and width are divided before the frames
# are stored in the training tensors.
IMAGE_REDUCTION_SIZE = 5

In [10]:
# Pre-allocate the network input tensor, laid out as
# (frame, channel, reduced height, reduced width), sized from the first frame.
_reference_frame = input_frames[0]
_input_tensor_shape = (
    len(input_frames),
    _reference_frame.shape[2],
    _reference_frame.shape[0] // IMAGE_REDUCTION_SIZE,
    _reference_frame.shape[1] // IMAGE_REDUCTION_SIZE,
)
input_frames_preprocessed = torch.zeros(_input_tensor_shape).to(device)

In [11]:
# Fill the input tensor: scale each frame to [0, 1], shrink it by
# IMAGE_REDUCTION_SIZE, and copy it channel by channel from the
# (height, width, channel) array layout into the (channel, height, width) slot.
for idx in range(input_frames_preprocessed.shape[0]):
    raw_frame = input_frames[idx]
    # cv.resize takes the target size as (width, height).
    reduced_size = (
        raw_frame.shape[1] // IMAGE_REDUCTION_SIZE,
        raw_frame.shape[0] // IMAGE_REDUCTION_SIZE,
    )
    normalized_frame = raw_frame / 255.0  # normalize image
    input_frame_small = torch.from_numpy(cv.resize(normalized_frame, reduced_size))

    for channel in range(3):
        input_frames_preprocessed[idx, channel, :, :] = input_frame_small[:, :, channel]

In [12]:
# Pre-allocate the target tensor: one 3-channel mask per segmentation frame.
# The spatial size is taken from the first *input* frame so that targets match
# the network input resolution.
_target_height = input_frames[0].shape[0] // IMAGE_REDUCTION_SIZE
_target_width = input_frames[0].shape[1] // IMAGE_REDUCTION_SIZE
output_frames_preprocessed = torch.zeros(
    (len(output_frames), 3, _target_height, _target_width)
).to(device)

In [13]:
# Convert each segmentation frame into a 3-channel one-hot mask stacked as
# (background, eef, shaft).
for idx in range(output_frames_preprocessed.shape[0]):
    raw_mask = output_frames[idx]
    # cv.resize takes the target size as (width, height).
    reduced_size = (
        raw_mask.shape[1] // IMAGE_REDUCTION_SIZE,
        raw_mask.shape[0] // IMAGE_REDUCTION_SIZE,
    )
    # Nearest-neighbour keeps the original label gray levels. The default
    # (bilinear) interpolation blends neighbouring labels at class borders, and a
    # blended value can fall inside the other class's gray range below.
    output_frame_small = cv.resize(raw_mask, reduced_size, interpolation=cv.INTER_NEAREST)

    # Classes are identified by gray level, so convert once and reuse.
    gray = cv.cvtColor(output_frame_small, code=cv.COLOR_BGR2GRAY)
    # cv.inRange yields 0/255; casting to bool gives a per-pixel class membership mask.
    eef = torch.from_numpy(cv.inRange(gray, 65, 75).astype(np.bool_))
    shaft = torch.from_numpy(cv.inRange(gray, 155, 165).astype(np.bool_))
    # Everything that is neither eef nor shaft is background.
    background = ~torch.bitwise_or(eef, shaft)

    output_frames_preprocessed[idx, :, :, :] = torch.stack((background.int(), eef.int(), shaft.int()), dim=0)

In [14]:
import torchvision

# DeepLabV3 with a ResNet-50 backbone, configured with three output channels to
# match the three mask channels (background, eef, shaft) built for the targets.
nnet = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=3)

In [15]:
# Move the model onto the same device (CUDA if available, else CPU) as the data tensors.
nnet = nnet.to(device)

In [16]:
from torch.utils.data import DataLoader, TensorDataset

# Pair every preprocessed frame with its mask by index, then serve the pairs in
# shuffled mini-batches of 64.
dataset = TensorDataset(input_frames_preprocessed, output_frames_preprocessed)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [17]:
# Per-channel binary cross-entropy computed on raw logits; the targets are the
# 0/1 mask channels built during preprocessing.
cost = torch.nn.BCEWithLogitsLoss()
# Adam over all model parameters.
optimizer = torch.optim.Adam(nnet.parameters(), lr=0.01)

In [18]:
# Per-batch training losses, appended by the training loop and plotted afterwards.
# Kept in its own cell so re-running the training cell extends the history.
losses = []

In [19]:
# Train for NUM_EPOCHS passes over the loader, recording the loss of every batch.

# Explicitly select training mode: the evaluation cell switches the model to
# eval(), so re-running this cell afterwards would otherwise train with
# batch-norm layers frozen in inference behaviour.
nnet.train()

for epoch in range(NUM_EPOCHS):
    for (x, y) in loader:
        # Clear gradients left over from the previous step before the new pass.
        optimizer.zero_grad()
        out = nnet(x)['out']
        loss = cost(out, y)
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float, so the history holds no tensors
        # (and avoids the discouraged `.data` attribute).
        batch_loss = loss.item()
        print(f"Loss: {batch_loss}")
        losses.append(batch_loss)
    print(f"Epoch: {epoch+1}/{NUM_EPOCHS}")

Loss: 0.6179870367050171
Loss: 0.4489821493625641
Loss: 0.5605427622795105
Loss: 0.22922766208648682
Loss: 0.1929745376110077
Loss: 0.15276920795440674
Loss: 0.1550777554512024
Loss: 0.15074682235717773
Loss: 0.14032474160194397
Loss: 0.13883720338344574
Loss: 0.12094806879758835
Loss: 0.11631610989570618
Loss: 0.10521328449249268
Loss: 0.10256200283765793


KeyboardInterrupt: 

In [None]:
import random
index_ = random.randint(0, len(input_frames)-1)

input_ = torch.Tensor(
            input_frames_preprocessed[index_]
        ).reshape(
            (1, 
            input_frames_preprocessed[index_].shape[0], 
            input_frames_preprocessed[index_].shape[1], 
            input_frames_preprocessed[index_].shape[2]
            )
        )
        
output_ = output_frames_preprocessed[index_]

with torch.no_grad():
    nnet.eval()
    pred = nnet(input_)['out']
    seg = torch.argmax(pred[0], 0, keepdim = True).cpu().detach().numpy()  # Get prediction classes
    
%matplotlib qt
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(output_.data.cpu()[0,:,:])
plt.colorbar()
plt.subplot(1, 2, 2)
plt.imshow(seg[0,:,:])
plt.colorbar()

<matplotlib.colorbar.Colorbar at 0x25cc6af43d0>

In [None]:
def to_n_channel_pred(tensor_in: torch.Tensor) -> torch.Tensor:
    """Collapse a channel-first mask stack into a single index map.

    Each output pixel is the sum over channels ``n >= 1`` of
    ``int(tensor_in[n] * n)``. Channel 0 therefore contributes nothing, and for
    a one-hot stack the result is the index of the active channel.

    Args:
        tensor_in: Tensor indexed as ``(channel, height, width)``.

    Returns:
        A ``torch.uint8`` tensor of shape ``(height, width)``, allocated on the
        same device as ``tensor_in``.
    """
    tensor_out = torch.zeros(
        size = (tensor_in.shape[1], tensor_in.shape[2]),
        dtype = torch.uint8,
        # Allocate next to the input; a CPU accumulator would make the in-place
        # add below fail for inputs that live on another device.
        device = tensor_in.device,
    )

    # Channel 0 is skipped: it would be weighted by zero anyway.
    for n in range(1, tensor_in.shape[0]):
        tensor_out += (tensor_in[n, :, :] * n).int()

    return tensor_out

In [None]:
import torchmetrics

# Dice score between the predicted class map and the ground-truth class map of
# the frame evaluated above; index 0 is passed as ignore_index.
predicted_classes = torch.argmax(pred[0], 0, keepdim = True).cpu()
target_classes = to_n_channel_pred(output_.data.cpu())
torchmetrics.functional.dice(predicted_classes, target_classes, ignore_index = 0)

tensor(0.9287)

In [None]:
# Plot the per-batch losses collected in `losses` (x axis: batch index in
# the order the batches were processed).
plt.figure()
plt.plot(losses)

[<matplotlib.lines.Line2D at 0x25cc6b34e50>]

In [None]:
# Display the model's state dict (all parameter and buffer tensors).
# NOTE(review): this prints every weight tensor, which makes for a very large cell output.
nnet.state_dict()

OrderedDict([('backbone.conv1.weight',
              tensor([[[[-2.6472e-01, -1.9102e-01, -3.4247e-01,  ..., -2.2264e-01,
                         -8.7132e-02,  6.8446e-03],
                        [ 4.9499e-02,  3.9041e-02, -6.8383e-02,  ..., -1.6765e-01,
                         -9.6502e-02, -2.2489e-01],
                        [-7.0923e-03, -5.5040e-02, -1.9433e-02,  ..., -1.2695e-01,
                         -2.1827e-03, -4.5803e-02],
                        ...,
                        [ 4.5483e-01,  8.9415e-02,  9.8211e-02,  ..., -1.9646e-01,
                         -1.1124e-01, -9.2523e-02],
                        [ 2.9170e-01, -8.5915e-02,  1.4855e-01,  ...,  1.2142e-01,
                          4.9973e-02, -5.4675e-02],
                        [-8.2907e-02, -5.7504e-02, -1.9516e-01,  ..., -1.0283e-01,
                          2.2014e-03,  1.5401e-02]],
              
                       [[-4.2029e-01, -2.5523e-01, -2.9862e-01,  ..., -1.4328e-01,
                       