# Data Exploration

In [None]:
import sys
sys.path.insert(0,'../src')

In [None]:
# imports
import os
from itertools import combinations

import torch
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import seaborn as sns

from utils import *
from config import *
from data import ImageDataset, VideoDataset

In [None]:
# build one ImageDataset per split, restricted to CLASSES
# (ratio=1.0 presumably keeps the whole split - confirm against ImageDataset)
data = {
    name: ImageDataset(split=name, include_classes=CLASSES, ratio=1.0)
    for name in SPLITS
}

# the id <-> class-name lookups are taken from the train split
train_set = data["train"]
id2class = train_set.id2class
class2id = train_set.class2id

In [None]:
# wrap each split's dataset in a DataLoader that uses the shared batch size
loader = {
    name: DataLoader(data[name], batch_size=BATCH_SIZE)
    for name in SPLITS
}

## Verify Disjointness of Splits

In [None]:
# collect the set of image paths that belongs to each split
paths = {}
for split in SPLITS:
    paths[split] = {image_path for image_path, _label in data[split].image_paths}

# report how many paths every unordered pair of splits shares
for fst, snd in combinations(SPLITS, 2):
    pair = (fst, snd)
    overlap = paths[fst] & paths[snd]
    print(f"{pair} has {len(overlap)} images in common")

## Verify Even Class Distribution

In [None]:
# one bar chart of the class distribution per split, laid out side by side;
# the column count follows SPLITS instead of being fixed at three
n_splits = len(SPLITS)
# squeeze=False always yields a 2-D array of axes, so row 0 is a 1-D array
# that can be indexed even when there is only a single split
fig, axes = plt.subplots(ncols=n_splits, figsize=(4 * n_splits, 3), squeeze=False)
ax = axes[0]
for i, split in enumerate(SPLITS):
    dist = data[split].class_distribution
    sns.barplot(x=list(dist.keys()), y=list(dist.values()), ax=ax[i])
    # rotate the tick labels so they do not overlap
    ax[i].tick_params(axis='x', rotation=90)
    ax[i].set(title=f"{split.capitalize()} Split")

## Show Examples

In [None]:
# train split: take the first batch and show it with class names as titles
train_batches = iter(loader["train"])
images, labels = next(train_batches)

titles = [id2class[label.item()] for label in labels]
show_images(images, titles=titles, show=True)

In [None]:
# val split: take the first batch and show it with class names as titles
val_batches = iter(loader["val"])
images, labels = next(val_batches)

titles = [id2class[label.item()] for label in labels]
show_images(images, titles=titles, show=True)

In [None]:
# test split: take the first batch and show it with class names as titles
test_batches = iter(loader["test"])
images, labels = next(test_batches)

titles = [id2class[label.item()] for label in labels]
show_images(images, titles=titles, show=True)

## Video Dataset

In [None]:
# build the video dataset from the default configuration that the class itself provides
config = VideoDataset.default_config()
video_dataset = VideoDataset(**config)

In [None]:
import random
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from IPython.display import HTML

# pick a random clip; randrange excludes len(video_dataset), whereas
# randint(0, len(video_dataset)) includes it and can index one past the end
clip, labels = video_dataset[random.randrange(len(video_dataset))]
image_tensors = list(clip)

# create a figure and axis object
fig, ax = plt.subplots()

# create an empty image object to hold the current frame;
# permute(1, 2, 0) moves the first axis last
# (presumably frames are C,H,W and imshow needs H,W,C - confirm against VideoDataset)
im = ax.imshow(np.zeros_like(image_tensors[0].permute(1, 2, 0)))


# define the update function that will be called for each frame
def update(i):
    """Draw frame ``i`` of the clip and return the modified artists (required for blitting)."""
    im.set_data(image_tensors[i].permute(1, 2, 0))
    return [im]


# one animation frame per clip frame, 500 ms apart
ani = animation.FuncAnimation(fig, update, frames=len(image_tensors), interval=500, blit=True)

# saving requires the ffmpeg writer to be available on this machine
ani.save('animation.mp4', writer='ffmpeg')

# last expression of the cell, so the notebook renders the saved video inline
HTML('<video controls src="animation.mp4" />')

In [None]:
# number of clips in the video dataset
len(video_dataset)

In [None]:
# tally the distinct clip lengths and how often each occurs
# (per the original note, only clips of 10 frames are expected)
clip_lengths = [len(frames) for frames, _ in video_dataset]
np.unique(clip_lengths, return_counts=True)