In [8]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt
import pickle
import gzip
import numpy as np
import os

In [9]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Every keyword argument becomes one panel. The keyword is used as the
    panel title (underscores are replaced by spaces and the words are
    title-cased) and the value is handed to ``plt.imshow`` unchanged.

    Args:
        **images: panel name -> image. Values are not inspected here, so
            anything ``plt.imshow`` can draw is allowed, including objects
            without a ``.shape`` attribute.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    n_panels = len(images)
    plt.figure(figsize=(16, 5))
    # subplot positions are 1-based, hence start=1
    for position, (name, image) in enumerate(images.items(), start=1):
        plt.subplot(1, n_panels, position)
        # hide the axis ticks; only the picture and its title are shown
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(image)
    plt.show()

In [10]:
def load_zipped_pickle(filename):
    """Read a gzip-compressed pickle file and return the object stored in it.

    Args:
        filename: path of the gzip file to open for binary reading.

    Returns:
        The object reconstructed by ``pickle.load``.

    Note:
        Unpickling can execute arbitrary code, so only use this on files
        from a trusted source.
    """
    with gzip.open(filename, mode='rb') as stream:
        restored = pickle.load(stream)
    return restored

In [11]:
def save_zipped_pickle(obj, filename, protocol=2):
    """Pickle ``obj`` and write it gzip-compressed to ``filename``.

    Args:
        obj: any picklable object.
        filename: destination path, opened for binary writing (an existing
            file is overwritten).
        protocol: pickle protocol number handed to ``pickle.dump``. The
            default of 2 is the value that used to be hard-coded here, so
            existing callers produce the same files as before. Pass a newer
            protocol (4 or above) when very large objects must be stored.

    Returns:
        None.
    """
    with gzip.open(filename, 'wb') as f:
        pickle.dump(obj, f, protocol)

In [16]:
# Load the three gzip-pickled inputs (train, test, sample) in that order.
trainX, testX, samples = map(
    load_zipped_pickle,
    ("task3data/train.pkl", "task3data/test.pkl", "task3data/sample.pkl"),
)

In [14]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [17]:
# Flatten the training set into (frame, label) pairs: three pairs per entry,
# taken at the indices stored in the first three positions of its "frames".
# NOTE(review): "frames" presumably lists the annotated frame indices — confirm.
trainX_extr = []
for sample in trainX:
    video, label, frames = sample["video"], sample["label"], sample["frames"]
    # the last axis of both arrays is the one indexed by the frame number
    pairs = [(video[:, :, frames[k]], label[:, :, frames[k]]) for k in range(3)]
    trainX_extr.extend(pairs)

In [19]:
# Flatten every test video into its individual frames (slices along the
# last axis), keeping video order and, within a video, frame order.
test_data = []
for sample in testX:
    video = sample["video"]
    n_frames = video.shape[2]
    test_data.extend(video[:, :, t] for t in range(n_frames))

In [18]:
class Dataset2(BaseDataset):
    """Mitral valve dataset built from in-memory ``(image, mask)`` pairs.

    For every item the image gets a trailing channel axis, the label is
    turned into one binary channel per requested class, and the optional
    augmentation and preprocessing callables are applied in that order.

    Args:
        data (sequence): ``(image, mask)`` pairs. Each image is indexed as a
            2-D array; each mask is compared element-wise against the
            selected class values.
        classes (list[str] | None): class names to extract, matched
            case-insensitively against ``CLASSES``. ``None`` (the default)
            selects every entry of ``CLASSES``.
        augmentation (callable | None): called as
            ``augmentation(image=..., mask=...)`` and must return a mapping
            with ``'image'`` and ``'mask'`` keys
            (e.g. an ``albumentations.Compose``).
        preprocessing (callable | None): same calling convention as
            ``augmentation``; applied after it.

    Raises:
        ValueError: if a name in ``classes`` is not listed in ``CLASSES``.
    """

    CLASSES = ['valve']

    def __init__(
            self,
            data,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.data = data
        # Split the pairs once so items can be fetched by index.
        self.images_fps = [pair[0] for pair in self.data]
        self.masks_fps = [pair[1] for pair in self.data]

        # The signature advertises classes=None, which used to crash when
        # iterated; treat it as "all known classes".
        if classes is None:
            classes = self.CLASSES

        # Convert class names to the values looked up in the masks.
        # NOTE(review): the value is the name's position in CLASSES, so
        # 'valve' maps to 0 and __getitem__ selects pixels where the label
        # equals 0. If the labels mark the valve with True/1 this is the
        # complement of the valve — confirm against the label encoding.
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i``.

        Without augmentation/preprocessing the image has shape ``(H, W, 1)``
        and the mask is a float array of shape ``(H, W, n_classes)``.
        """
        # add a trailing channel axis: (H, W) -> (H, W, 1)
        image = self.images_fps[i][:, :, np.newaxis]
        mask = self.masks_fps[i]

        # one binary channel per selected class value, stacked on the last axis
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of ``(image, mask)`` pairs in the dataset."""
        return len(self.data)