In [28]:
from __future__ import print_function, division
import os
import torch
# import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import glob
from torchvision.io import read_video
from torch import from_numpy
import skimage
import cv2
from skimage import metrics as sm
import torchvision.transforms.functional as F

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [9]:
class MRIDataset(Dataset):
    """MRI dataset."""

    def __init__(self, root_dir, transform=None):
        """
        Arguments:
            root_dir (string): Directory with all the videos.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.paths = self.load_paths(root_dir)
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # if torch.is_tensor(idx):
        #     idx = idx.tolist()

        vid_name = self.paths[idx]
        sample = self.load_image_batches(vid_name, 1, 24)

        if self.transform:
            sample = self.transform(sample)

        return sample
    



    # function to extract the paths for files froma path
    def load_paths(self, data_path):
        files = []
        files.append(glob.glob(data_path, 
                    recursive = True))
        return files[0]
    




    def load_image_batches(self, path, step = 10, smooth = 5):
        # path = "C:/Users/Misha/OneDrive - University of Sussex/FYP/Participants/Participant_12/Processed_data/Video/Subject_12_03.mp4"

        frames, _, _ = read_video(str(path), output_format="TCHW") # returns video frames, audio frames, metadata for the video and audio
        # print(frames.shape)
        images1 = []
        images2 = []
        img = np.array([np.moveaxis(frames[a].numpy()[:,:,:], 0, -1) for a in range(len(frames))])
        # print(img.shape)
        # print(sm.structural_similarity(img[0, 200:500, 600:900, 0], img[1, 200:500, 600:900, 0]))
        # len(frames)-2
        for i in range(0, len(frames)-2, step):
            img1 = img[i, 100:600, 500:1000, :]
            img2 = img[i+1, 100:600, 500:1000, :]
            if (sm.structural_similarity(img1[:,:,0], img2[:,:,0]) < 0.90):

                new_shape = (img1.shape[0] , img1.shape[1] , img1.shape[2])
                blurred1 = skimage.transform.resize(image=img1, output_shape=new_shape).astype(np.float32)
                blurred2 = skimage.transform.resize(image=img2, output_shape=new_shape).astype(np.float32)

                blurred_1 = np.moveaxis(cv2.bilateralFilter(blurred1,smooth,160,160), -1, 0)
                blurred_2 = np.moveaxis(cv2.bilateralFilter(blurred2,smooth,160,160), -1, 0)

                images1.append(torch.from_numpy(blurred_1))
                images2.append(torch.from_numpy(blurred_2))

        if len(images2) < 1: #when the video is too short/ redundant video
            img1_batch = images1
            img2_batch = images2
        else:
            img1_batch = torch.stack(images1) # making predictions between 2 pairs of frames 53 and 83, and 84 and 130
            img2_batch = torch.stack(images2)

        return img1_batch, img2_batch

In [37]:
class Rescale(object):
    def __init__(self, output_size) -> None:
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        print(type(sample))
        print(sample[0].shape)
        img1_batch = F.resize(sample[0], size=[self.output_size, self.output_size], antialias=False) # resize frames to ensure they are divisable by 8 
        img2_batch = F.resize(sample[1], size=[self.output_size, self.output_size], antialias=False)

        return (img1_batch, img2_batch)

In [27]:

from torchvision.models.optical_flow import Raft_Large_Weights
from torchvision.models.optical_flow import raft_large

In [46]:
class RaftTransforms(object):
    def __call__(self, sample):
        weights = Raft_Large_Weights.DEFAULT
        raft_transforms = weights.transforms()
        return raft_transforms(sample[0], sample[1])

In [10]:
video_dataset = MRIDataset("/Users/men22/OneDrive - University of Sussex/FYP/Participants/VIDEOS_ALL/*.mp4")

In [38]:
type(video_dataset[0])

<class 'tuple'>
torch.Size([2, 3, 500, 500])


UnboundLocalError: local variable 'img1_batch' referenced before assignment

In [47]:
composed = transforms.Compose([Rescale(256),
                               RaftTransforms()])

In [48]:
video_dataset = MRIDataset("/Users/men22/OneDrive - University of Sussex/FYP/Participants/VIDEOS_ALL/*.mp4", transform=composed)

In [51]:
video_dataset[0][0].shape

<class 'tuple'>
torch.Size([2, 3, 500, 500])


torch.Size([2, 3, 256, 256])