In [None]:
import glob
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import cv2
import time

In [None]:
from torchvision.datasets.vision import VisionDataset
from torch.utils.data import IterableDataset
# from torchvision.datasets.video_utils import VideoClips
from video_clip import VideoClips
import torch.utils.data as data
from bat_seg_models import ThreeLayerSemSegNetWideView, UNET, UNETTraditional
from frame_augmentors import MaskNormalize, Mask3dto2d, AddDim, ToFloat, MaskCompose, MaskToTensor
import bat_functions
from CountLine import CountLine

In [None]:
# Sanity check on one extracted frame; plt.imread returns an array of shape
# (rows, cols[, channels]). matplotlib.pyplot is already imported as plt above.
im_file = ".../kasanka-bats/frames/17Nov/card-f/GP039791/GP039791_15948.jpg"
im = plt.imread(im_file)

im.shape

In [None]:
root_output_folder = '.../kasanka-bats/processed/deep-learning/corrected_model'
date = '16Nov'
os.makedirs(root_output_folder, exist_ok=True)

raw_camera_folders = sorted(glob.glob('.../kasanka-bats/gopros/{}/*'.format(date)))

# Queue only cameras that (a) actually contain videos and (b) have not been
# processed yet. The processing cell writes centers.npy as its final output,
# so its presence marks a finished camera.
camera_folders = []
for camera_folder in raw_camera_folders:
    videos = sorted(glob.glob(os.path.join(camera_folder, '*.[Mm][Pp]4')))
    camera_name = os.path.basename(camera_folder)
    if not videos:
        # A dataset cannot be built from an empty video list; skip instead of
        # crashing later when the dataset is constructed.
        print('no videos found in {}, skipping'.format(camera_folder))
        continue
    if not os.path.exists(os.path.join(root_output_folder, date, camera_name, 'centers.npy')):
        print(*videos, sep='\n')
        print('--------------')
        camera_folders.append(camera_folder)
        



In [None]:
class BatIterableDataset(IterableDataset):
    """Stream frames from a list of video files, one file after another.

    Each successfully read frame is converted from OpenCV's BGR order to RGB,
    cropped by 2 pixels on every edge, wrapped as ``{'image': frame}`` and
    passed through ``augmentor`` when one is given. A video is abandoned and
    the next one started when ``max_bad_reads`` consecutive reads fail or
    when a frame is judged too dark.

    NOTE(review): the capture and all counters live on the instance, so a
    second iteration resumes where the previous one stopped instead of
    restarting from the first video.
    """

    def __init__(self, video_files, augmentor=None, max_bad_reads=300):
        """
        Args:
            video_files: sequence of video paths, read in the given order.
            augmentor: optional callable applied to each ``{'image': ...}`` dict.
            max_bad_reads: consecutive failed reads tolerated before a video
                is assumed to have ended.

        Raises:
            ValueError: if ``video_files`` is empty.
            AssertionError: if the first video cannot be opened.
        """
        if len(video_files) == 0:
            raise ValueError('video_files must contain at least one video')
        self.vid_cap = cv2.VideoCapture(video_files[0])
        self.video_files = video_files
        assert self.vid_cap.isOpened()
        self.more_frames = True
        # How many times a frame can come up false
        # before assuming end of video
        self.max_bad_reads = max_bad_reads
        self.total_frames_read = 0
        self.total_bad_reads = 0
        self.augmentor = augmentor
        # Index into video_files of the video currently being read.
        self.video_number = 0

    def more_videos(self):
        """Return True while ``video_number`` still indexes an entry of ``video_files``."""
        return self.video_number < len(self.video_files)

    def start_next_video(self):
        """Release the current capture and open the next video, if there is one."""
        if self.vid_cap.isOpened():
            self.vid_cap.release()
        self.video_number += 1
        if self.video_number < len(self.video_files):
            print('starting new video')
            print(self.get_read_frame_info())
            self.vid_cap = cv2.VideoCapture(self.video_files[self.video_number])

    def video_generator(self):
        """Yield (optionally augmented) frame dicts until every video is exhausted."""
        while self.vid_cap.isOpened() or self.more_videos():
            if not self.vid_cap.isOpened():
                self.start_next_video()
                if not self.vid_cap.isOpened():
                    # Either every video has been consumed (nothing was
                    # opened) or the next file failed to open. Re-check the
                    # loop condition rather than burning max_bad_reads reads
                    # on a closed capture.
                    if self.more_videos():
                        print('could not open {}'.format(self.video_files[self.video_number]))
                    continue
            good_read = False
            num_bad_reads = 0
            while not good_read and num_bad_reads < self.max_bad_reads:
                grabbed, frame = self.vid_cap.read()
                if grabbed:
                    good_read = True
                    self.total_frames_read += 1
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = {'image': frame[2:-2, 2:-2]}
                    # Cheap darkness test: mean of every 50th pixel of channel 2
                    # (blue, after the BGR->RGB conversion). A dark frame ends
                    # this video.
                    if np.mean(frame['image'][::50, ::50, 2]) < 5:
                        print('too dark')
                        self.vid_cap.release()
                        break

                    if self.augmentor:
                        frame = self.augmentor(frame)
                    yield frame
                else:
                    num_bad_reads += 1
                    self.total_bad_reads += 1
            if not good_read:
                self.vid_cap.release()
                print("video capture closed")

    def __iter__(self):
        return self.video_generator()

    def __del__(self):
        # getattr: __init__ may have raised before vid_cap was assigned.
        vid_cap = getattr(self, 'vid_cap', None)
        if vid_cap is not None and vid_cap.isOpened():
            vid_cap.release()

    def is_more_frames(self):
        """Return True while the current capture is still open."""
        return self.vid_cap.isOpened()

    def get_read_frame_info(self):
        """Return a one-line summary of frames read and failed reads so far.

        The message is returned rather than printed; callers print it.
        """
        return '{} frames have been read with {} bad reads'.format(self.total_frames_read,
                                                                   self.total_bad_reads)

In [None]:
folder = './models'

# Earlier checkpoints, kept for reference:
#   model_ThreeLayerWide_epochs_10_batcheff_4_lr_0.05_momentum_0.5_aug_aug-no-blur-2d-17Nov-big-dataset.tar
#   model_UNET_epochs_100_batcheff_16_lr_0.01_momentum_0.9_aug_aug-2d-20Nov-big-dataset.tar
#   model_UNET_epochs_100_batcheff_16_lr_0.01_momentum_0.9_aug_better-norm-aug-2d-20Nov-big-dataset.tar
# The checkpoint loaded must match the architecture constructed in the
# inference cell (UNETTraditional).
model_filename = 'model_UNETTraditional_epochs_100_batcheff_16_lr_0.01_momentum_0.9_aug_better-norm-aug-2d-20Nov-big-dataset.tar'
model_file = os.path.join(folder, model_filename)

In [None]:
root_train_folder = ".../kasanka-bats/annotations"
# Normalisation statistics, indexed by channel below and divided by 255
# before use (presumably computed on 0-255 pixel values — TODO confirm).
mean = np.load(os.path.join(root_train_folder, 'mean.npy'))
std = np.load(os.path.join(root_train_folder, 'std.npy'))

# Channel selected by Mask3dto2d and used to pick the normalisation stats.
channel = 2

bat_datasets = []
for camera_folder in camera_folders:
    videos = sorted(glob.glob(os.path.join(camera_folder, '*.[Mm][Pp]4')))
    if not videos:
        # BatIterableDataset opens video_files[0] on construction, so an
        # empty folder would crash here; skip it instead.
        print('no videos found in {}, skipping'.format(camera_folder))
        continue
    augmentor = MaskCompose([Mask3dto2d(channel_to_use=channel),
                             MaskToTensor(),
                             MaskNormalize(mean[channel]/255, std[channel]/255),
                             ])
    bat_dataset = BatIterableDataset(videos, augmentor=augmentor)
    # Mirror the input layout: <root_output_folder>/<date>/<camera>.
    save_folder = os.path.join(root_output_folder, *camera_folder.split('/')[-2:])
    os.makedirs(save_folder, exist_ok=True)
    os.makedirs(os.path.join(save_folder, 'example-frames'), exist_ok=True)
    bat_datasets.append({'dataset': bat_dataset,
                         'save_folder': save_folder})
                                    

In [None]:
# Preview one preprocessed frame from the first camera. BatIterableDataset
# keeps its read position on the instance, so pulling a frame from
# bat_datasets[0]['dataset'] itself would silently drop that frame from the
# detection run below. Read it through a throwaway dataset over the same
# videos instead.
first_dataset = bat_datasets[0]['dataset']
preview_dataset = BatIterableDataset(first_dataset.video_files,
                                     augmentor=first_dataset.augmentor)
i = next(iter(preview_dataset))
preview_dataset.vid_cap.release()

In [None]:
# Show the preview frame as produced by the dataset's augmentor, with a
# colour scale for the normalised intensity values.
fig, ax = plt.subplots(figsize=(20, 20))
frame_img = ax.imshow(np.squeeze(i['image']))
fig.colorbar(frame_img, ax=ax)


In [None]:
# Output folder of every camera queued for processing. This is read from
# bat_datasets rather than the loop variable leaked by the dataset cell,
# which only held the last camera.
[entry['save_folder'] for entry in bat_datasets]

In [None]:
def logit2prob(logit):
    """Convert log-probabilities to probabilities by exponentiating them.

    Despite the name, this is ``exp(x)``, not the logistic sigmoid. It is
    applied to the network's class-1 output, which the inference loop
    compares against ``np.log(bat_prob_thresh)``, i.e. treats as a
    log-probability.

    Args:
        logit: scalar or array of log-probabilities.

    Returns:
        ``exp(logit)``, elementwise, with the same shape as the input.
    """
    return np.exp(logit)

def denorm_image(im, mean, std):
    """Undo ``(x - mean) / std`` normalisation and map the result to uint8 pixels.

    The input array is left untouched.

    Args:
        im: normalised image array (any shape). Integer arrays are promoted
            to float32; float arrays keep their precision.
        mean: normalisation mean, on a 0-1 scale (the result is scaled by 255).
        std: normalisation standard deviation, on the same 0-1 scale.

    Returns:
        ``np.uint8`` array of the same shape, clipped to [0, 255] and
        truncated (not rounded) to integers.
    """
    # Work on a float copy. Callers may pass views (e.g. ``tensor.numpy()``
    # shares memory with the tensor), and in-place float ops would fail on
    # an integer array.
    arr = np.asarray(im)
    out = np.array(arr, dtype=np.result_type(arr.dtype, np.float32), copy=True)
    out *= std
    out += mean
    out *= 255
    out = np.clip(out, 0, 255)
    return out.astype(np.uint8)

should_plot = False
should_save = True

num_classes = 2
bat_prob_thresh = .6
batch_size = 2
# Stop after this many batches (quick debugging runs); None processes everything.
early_stop = None
# save some original frames to check detection quality
save_every_n_frames = 1350
channel = 2
# Contours for one camera are split over roughly this many .npy files.
num_contour_files = 15


def _to_object_array(items):
    """Pack a Python list into a 1-D object array with exactly one slot per item.

    ``np.array(items, dtype=object)`` (and ``np.save`` on a plain list) lets
    NumPy look inside the items. Ragged per-frame results can then raise
    ("inhomogeneous shape" / "could not broadcast") or be merged into a
    deeper array. Filling a pre-allocated array slot by slot avoids both.
    """
    packed = np.empty(len(items), dtype=object)
    for k, item in enumerate(items):
        packed[k] = item
    return packed


def _save_contours(save_folder, contours_list, num_contour_files):
    """Simplify every contour with approxPolyDP and save them in chunked files.

    Writes ``contours-compressed-NN.npy`` into ``save_folder``, each a 1-D
    object array holding one list of contours per frame. The layout matches
    what this notebook has always produced: a file is flushed whenever the
    frame index is a multiple of the chunk size, so file ``00`` is always empty.
    """
    # max(1, ...) avoids a modulo-by-zero when there are fewer frames than files.
    chunk_size = max(1, len(contours_list) // num_contour_files)
    file_num = 0
    new_contours = []
    for frame_ind, frame_contours in enumerate(contours_list):
        if frame_ind % chunk_size == 0:
            # start new file
            file_name = f'contours-compressed-{file_num:02d}.npy'
            np.save(os.path.join(save_folder, file_name), _to_object_array(new_contours))
            new_contours = []
            file_num += 1
        new_contours.append([np.squeeze(cv2.approxPolyDP(c, 0.1, closed=True))
                             for c in frame_contours])
    file_name = f'contours-compressed-{file_num:02d}.npy'
    np.save(os.path.join(save_folder, file_name), _to_object_array(new_contours))


device = torch.device("cuda")
model = UNETTraditional(1, 2, should_pad=False)
model.load_state_dict(torch.load(model_file, map_location=device))
model.to(device)
model.eval()

for entry in bat_datasets:
    dataset = entry['dataset']
    save_folder = entry['save_folder']
    print(save_folder)

    dataloader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=False, num_workers=0,
                                 pin_memory=True)

    num_frames = 0
    centers_list = []
    contours_list = []
    sizes_list = []
    rects_list = []

    # Start timing before the loop so the summary below is defined even if
    # the dataloader yields nothing (e.g. every video was too dark).
    t0 = time.time()
    for batch_ind, batch in enumerate(dataloader):
        if batch_ind == 0:
            print('started...')
        if early_stop and batch_ind >= early_stop:
            break

        im_batch = batch['image'].to(device)

        with torch.no_grad():
            outputs = model(im_batch)
            # Channel 1 is thresholded in log space, i.e. treated as
            # log(P(bat)) — presumably a log-softmax output; TODO confirm.
            masks = (outputs[:, 1].cpu().numpy() > np.log(bat_prob_thresh)).astype(np.uint8)

            for ind, mask in enumerate(masks):
                centers, areas, contours, _, _, rects = bat_functions.get_blob_info(mask)
                centers_list.append(centers)
                sizes_list.append(areas)
                contours_list.append(contours)
                rects_list.append(rects)
                if save_every_n_frames and num_frames % save_every_n_frames == 0:
                    day = save_folder.split('/')[-2]
                    card = save_folder.split('/')[-1]
                    im_name = '{}_{}_obs-ind_{}.jpg'.format(day, card, num_frames)
                    im_file = os.path.join(save_folder, 'example-frames', im_name)
                    # Copy: .numpy() shares memory with the batch tensor.
                    im = np.squeeze(batch['image'][ind].numpy()).copy()
                    im = denorm_image(im, mean[channel]/255, std[channel]/255)
                    cv2.imwrite(im_file, im)
                num_frames += 1

        if should_plot:
            for ind in range(len(im_batch)):
                if 'orig' in batch:
                    plt.figure(figsize=(10, 10))
                    plt.imshow(batch['orig'][ind])
                plt.figure(figsize=(10, 10))
                # (C, H, W) -> (H, W, C), then drop a singleton channel:
                # imshow rejects (H, W, 1) arrays.
                plt.imshow(np.squeeze(np.transpose(im_batch[ind].cpu().numpy(), (1, 2, 0))))
                plt.figure(figsize=(10, 10))
                plt.imshow(outputs[ind][0].cpu().numpy())
                plt.title('output')
                # Threshold with bat_prob_thresh so the plot matches the saved masks.
                prob = logit2prob(outputs[ind, 1].cpu().numpy())
                plt.figure(figsize=(10, 10))
                plt.imshow((prob > bat_prob_thresh).astype(np.uint8))

    total_time = time.time() - t0
    seconds_per_frame = total_time / num_frames if num_frames else float('nan')
    print(total_time, seconds_per_frame)
    print('{} frames have been read with {} bad reads'.format(dataset.total_frames_read,
                                                              dataset.total_bad_reads))
    if should_save:
        _save_contours(save_folder, contours_list, num_contour_files)
        # Per-frame results are presumably ragged (a different number of blobs
        # per frame), so store them as 1-D object arrays; load them with
        # np.load(..., allow_pickle=True).
        np.save(os.path.join(save_folder, 'size.npy'), _to_object_array(sizes_list))
        np.save(os.path.join(save_folder, 'rects.npy'), _to_object_array(rects_list))
        # centers.npy is written last: the camera-discovery cell treats its
        # presence as "this camera is done".
        np.save(os.path.join(save_folder, 'centers.npy'), _to_object_array(centers_list))