# Demo VideoDataFrame Class 
This demo uses the Algonauts dataset.
    
TABLE OF CODE CONTENTS:
1. Minimal demo without image transforms
2. Minimal demo with a single continuous frame clip (no sparse temporal sampling), without image transforms
3. Demo with image transforms
4. Demo with image transforms and dataloader
5. Demo with image transforms, dataloader and K-fold Cross-Validation

For more details about the VideoDataFrame Class, see the [VideoDataset documentation](https://video-dataset-loading-pytorch.readthedocs.io/en/latest/VideoDataset.html)

### Setup 

In [None]:
# IPython autoreload in mode 2: re-import all modules before each cell runs,
# so edits to the package source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import os.path as op
from pathlib import Path
import Buzznauts as buzz
# Repository root: the parent of the directory the Buzznauts package was imported from
package_dir = Path(buzz.__path__[0])
buzz_root = package_dir.parent.absolute()

# Data paths (everything lives under <buzz_root>/data)
data_dir = op.join(buzz_root, "data")
fmri_dir = op.join(data_dir, "fmri")
stimuli = op.join(data_dir, "stimuli")
videos_dir = op.join(stimuli, "videos")
frames_dir = op.join(stimuli, "frames")
annotation_file = op.join(frames_dir, "annotations.txt")

In [None]:
from torchvision import transforms
import torch
from Buzznauts.data.utils import plot_video_frames
from Buzznauts.data.videodataframe import VideoFrameDataset, ImglistToTensor, FrameDataset

In [None]:
from Buzznauts.utils import seed_worker, set_generator

### Demo 1 - Sampled Frames, without Image Transforms

In [None]:
# 3 segments with 1 frame each, no transform: samples keep their PIL images
dataset = VideoFrameDataset(
    root_path=frames_dir,
    annotationfile_path=annotation_file,
    imagefile_template='img_{:05d}.jpg',
    num_segments=3,
    frames_per_segment=1,
    random_shift=True,
    test_mode=False,
    transform=None,
)

sample = dataset[0]
# frames: list of PIL images, label: integer label
frames, label = sample[0], sample[1]

# One row, one column per sampled frame
plot_video_frames(frame_list=frames, rows=1, cols=3, plot_width=15., plot_height=3.)

### Demo 2 - Single Continuous Frame Clip instead of Sampled Frames, without Image Transforms 

In [None]:
# A single segment of 9 frames per video (one continuous clip), no transform
clip_config = dict(
    root_path=frames_dir,
    annotationfile_path=annotation_file,
    num_segments=1,
    frames_per_segment=9,
    imagefile_template='img_{:05d}.jpg',
    transform=None,
    random_shift=True,
    test_mode=False,
)
dataset = VideoFrameDataset(**clip_config)

sample = dataset[5]
frames = sample[0]  # list of PIL images
label = sample[1]  # integer label

# 9 frames laid out on a 3 x 3 grid
plot_video_frames(frame_list=frames, rows=3, cols=3, plot_width=10., plot_height=5.)

### Demo 3 - Sampled Frames, with Image Transforms 

In [None]:
def denormalize(video_tensor):
    """Undo mean/std normalization, zero-to-one scaling and channel-first layout for a batch of frames.

    The inverse of ``transforms.Normalize(mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])`` is applied, the result is rescaled to
    [0, 255] and the channel axis is moved last so the frames can be plotted.

    Parameters
    ----------
    video_tensor : torch.FloatTensor
        A (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor with 3 channels.
        The input is not modified.

    Returns
    -------
    video_array : numpy.ndarray[uint8]
        A (FRAMES x HEIGHT x WIDTH x CHANNELS) numpy array with values in [0, 255]
    """
    mean = torch.tensor([0.485, 0.456, 0.406],
                        dtype=video_tensor.dtype, device=video_tensor.device)
    std = torch.tensor([0.229, 0.224, 0.225],
                       dtype=video_tensor.dtype, device=video_tensor.device)

    # normalized = (x - mean) / std  =>  x = normalized * std + mean
    # (statistics are broadcast over the channel axis)
    pixels = (video_tensor * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)) * 255.

    # Round and clamp before the integer cast: a bare cast truncates (turning
    # e.g. 127.99999 into 127) and wraps around for out-of-range values.
    pixels = pixels.round().clamp(0, 255).to(torch.uint8)

    # Channels last for plotting; .cpu() so tensors on an accelerator also work
    return pixels.permute(0, 2, 3, 1).cpu().numpy()

In [None]:
# Sampling configuration: 5 segments x 1 frame = 5 frames per video
num_segments, frames_per_segment = 5, 1
total_frames = num_segments * frames_per_segment

# As of torchvision 0.8.0, torchvision transforms support batches of images
# of size (BATCH x CHANNELS x HEIGHT x WIDTH) and apply deterministic or random
# transformations on the batch identically on all images of the batch. Any torchvision
# transform for image augmentation can thus also be used for video augmentation.
to_tensor = ImglistToTensor()  # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
resize = transforms.Resize(128)  # image batch, resize smaller edge to 128
center_crop = transforms.CenterCrop((100, 128))  # image batch, center crop to 100 x 128 (HEIGHT x WIDTH)
standardize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
normalize = transforms.Compose([to_tensor, resize, center_crop, standardize])

# Random augmentations, handed to the FrameDataset wrapper below
random_affine = transforms.RandomAffine(degrees=15, translate=(0.05, 0.05), scale=(0.78125, 1.0))
random_flip = transforms.RandomHorizontalFlip(p=0.5)
preprocess = transforms.Compose([random_affine, random_flip])

# Video-level dataset: tensor conversion, resize, crop and normalization only
dataset_without_preprocessing = VideoFrameDataset(
    root_path=frames_dir,
    annotationfile_path=annotation_file,
    imagefile_template='img_{:05d}.jpg',
    num_segments=num_segments,
    frames_per_segment=frames_per_segment,
    transform=normalize,
    random_shift=False,
    test_mode=False,
)

# Wrapper around the video-level dataset that additionally applies `preprocess`
dataset_with_preprocessing = FrameDataset(
    videoframedataset=dataset_without_preprocessing,
    transform=preprocess,
)

In [None]:
# Size of the FrameDataset wrapper. The indexing used in the plotting cell
# (range(2*total_frames, 3*total_frames)) suggests it counts individual
# frames rather than videos -- TODO confirm against FrameDataset.__len__.
len(dataset_with_preprocessing)

In [None]:
# Video shown in both plots. Defined once so the two views cannot drift apart,
# and so the plots follow `total_frames` instead of a hard-coded 5 columns.
video_idx = 2

print('Sample without Preprocessing')
print('----------------------------')
sample_without_preprocessing = dataset_without_preprocessing[video_idx]
frame_tensor = sample_without_preprocessing[0]  # tensor of shape (NUM_SEGMENTS*FRAMES_PER_SEGMENT) x CHANNELS x HEIGHT x WIDTH
frame_array = denormalize(frame_tensor)
plot_video_frames(rows=1, cols=total_frames, frame_list=frame_array, plot_width=15., plot_height=3.)

print('Sample with Preprocessing')
print('-------------------------')
# Assumes FrameDataset stores the frames of video k at indices
# [k * total_frames, (k + 1) * total_frames) -- TODO confirm against FrameDataset.
first_frame = video_idx * total_frames
frame_tensor = torch.stack(
    [dataset_with_preprocessing[i][0] for i in range(first_frame, first_frame + total_frames)],
    dim=0)
frame_array = denormalize(frame_tensor)
plot_video_frames(rows=1, cols=total_frames, frame_list=frame_array, plot_width=15., plot_height=3.)

### Demo 4 - Sampled Frames, with Image Transforms and Dataloader

In [None]:
# NOTE: when the notebook is run top to bottom, `dataset` still refers to the
# Demo 2 dataset (transform=None), whose samples are lists of PIL images that
# the default collate function cannot batch. Use the transformed dataset
# built in Demo 3 instead.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset_with_preprocessing,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=set_generator())

for epoch in range(10):
    for video_batch, labels in dataloader:
        # Insert Training Code Here
        print(labels)
        print("\nVideo Frames Batch Tensor Size:", video_batch.size())
        break  # a single batch is enough for the demo
    break  # ... and so is a single epoch

### Demo 5 - K-fold Cross-Validation with Sampled Frames Dataloader and Image Transforms

In [None]:
from sklearn.model_selection import KFold

In [None]:
# Sampling configuration: 5 segments x 6 frames = 30 frames per video
num_segments = 5
frames_per_segment = 6
total_frames = num_segments * frames_per_segment

# Per-channel statistics used by the normalization step
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# As of torchvision 0.8.0, torchvision transforms support batches of images
# of size (BATCH x CHANNELS x HEIGHT x WIDTH) and apply deterministic or random
# transformations on the batch identically on all images of the batch. Any torchvision
# transform for image augmentation can thus also be used for video augmentation.
normalize = transforms.Compose([
    ImglistToTensor(),  # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
    transforms.Resize(128),  # image batch, resize smaller edge to 128
    transforms.CenterCrop((100, 128)),  # image batch, center crop to 100 x 128 (HEIGHT x WIDTH)
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

# Random augmentations, handed to the FrameDataset wrapper below
augmentations = [
    transforms.RandomAffine(degrees=15, translate=(0.05, 0.05), scale=(0.78125, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
]
preprocess = transforms.Compose(augmentations)

# Video-level dataset: tensor conversion, resize, crop and normalization only
video_dataset_kwargs = dict(
    root_path=frames_dir,
    annotationfile_path=annotation_file,
    num_segments=num_segments,
    frames_per_segment=frames_per_segment,
    imagefile_template='img_{:05d}.jpg',
    transform=normalize,
    random_shift=False,
    test_mode=False,
)
dataset_without_preprocessing = VideoFrameDataset(**video_dataset_kwargs)

# Dataset used for cross-validation: the wrapper additionally applies `preprocess`
dataset = FrameDataset(videoframedataset=dataset_without_preprocessing, transform=preprocess)

In [None]:
# NOTE(review): environment-specific (Google Drive mount) location of the
# pretrained encoder weights; adjust when running elsewhere.
pretrained_path = '/content/drive/MyDrive/Buzznauts/data/pretrained/vaegan_enc_weights.pickle'

def reset_weights(model, pretrained_path=None):
    """Re-initialize model weights to avoid weight leakage, then optionally reload pretrained weights.

    Every module reachable from ``model`` (including nested ones) that defines
    ``reset_parameters`` is re-initialized. Call the function directly on the
    root model. ``model.apply(reset_weights)`` also works because
    ``pretrained_path`` is optional, but then nested layers are reset more
    than once and no pretrained weights are loaded.

    Parameters
    ----------
    model: torch.nn.Module
        Model whose layers are re-initialized in place.
    pretrained_path: str, optional
        Path handed to ``load_vaegan_weights``. When None (default), only the
        re-initialization is performed and nothing is loaded.
    """
    # modules() walks the whole module tree; children() would stop at the
    # direct children and leave layers inside nested containers untouched.
    for layer in model.modules():
        if hasattr(layer, 'reset_parameters'):
            print(f'Reset trainable parameters of layer = {layer}')
            layer.reset_parameters()

    if pretrained_path is None:
        return

    # NOTE(review): load_vaegan_weights is not imported in the cells shown
    # here; it must be in scope before this branch is reached.
    pretrained = load_vaegan_weights(model, pretrained_path)
    model.load_my_state_dict(pretrained)

In [None]:
# Configuration options
k_folds = 5
num_epochs = 1
K_VAE = 1024 # size of the latent space vector

# Define the K-fold Cross Validator
# NOTE(review): no random_state is given, so the fold assignment changes from run to run.
kfold = KFold(n_splits=k_folds, shuffle=True)

# Settings shared by the training and validation loaders of every fold
loader_kwargs = dict(
    dataset=dataset,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
    worker_init_fn=seed_worker)

# K-fold Cross Validation model evaluation
# NOTE(review): the folds are drawn over the items of `dataset`. If those items are
# single frames, frames of one video can land in both the training and the
# validation split -- TODO confirm this is intended.
for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
    print(f'FOLD {fold}')
    print('-------------------------')

    # Sample elements randomly from a given list of idx, no replacement
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

    # Define data loaders for training and validation data in this fold
    # (each loader gets its own generator from set_generator())
    train_dataloader = torch.utils.data.DataLoader(
        sampler=train_subsampler,
        generator=set_generator(),
        **loader_kwargs)

    val_dataloader = torch.utils.data.DataLoader(
        sampler=val_subsampler,
        generator=set_generator(),
        **loader_kwargs)

    # Init the neural network
    # NOTE(review): ConvVarAutoEncoder is not imported in the cells shown here;
    # it must be in scope before this cell runs.
    network = ConvVarAutoEncoder(K=K_VAE)
    # Module.apply() hands only the submodule to its callback, so the path of the
    # pretrained weights has to be passed through a direct call instead.
    reset_weights(network, pretrained_path)