In [None]:
import time
from datetime import datetime
import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.animation import FuncAnimation

import sys
sys.path.insert(0, '/home/dwh48@drexel.edu/sparse_coding_torch')

from feature_extraction.conv_sparse_model import ConvSparseLayer
from data_classifiers.small_data_classifier import SmallDataClassifier

from sklearn.model_selection import train_test_split

from feature_extraction.train_conv3d_sparse_model import load_balls_data
from feature_extraction.train_conv3d_sparse_model import plot_original_vs_recon
from feature_extraction.train_conv3d_sparse_model import plot_filters
from feature_extraction.train_conv3d_sparse_model import plot_video

from utils.load_data import load_bamc_clips, load_covid_data

from IPython.display import HTML

from tqdm import tqdm


In [None]:
# Prefer the first GPU; fall back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

batch_size = 1

# Only the first loader returned is used; the second value is discarded.
# NOTE(review): the positional 1.0 and the keyword options are interpreted by
# load_bamc_clips, which is not visible here — confirm their meaning there.
train_loader, _ = load_bamc_clips(
    batch_size,
    1.0,
    sparse_model=None,
    device=None,
    num_frames=4,
    seed=42,
)
print('Loaded', len(train_loader), 'train examples')

# A single sample batch, available to later cells (e.g. for shape queries).
example_data = next(iter(train_loader))

# 3D convolutional sparse-coding layer: single-channel input, 64 filters of
# kernel size (4, 16, 16), non-negative (rectified) activations.
sparse_layer_config = {
    'in_channels': 1,
    'out_channels': 64,
    'kernel_size': (4, 16, 16),
    'stride': 1,
    'padding': 0,
    'convo_dim': 3,
    'rectifier': True,
    'lam': 0.05,
    'max_activation_iter': 75,
    'activation_lr': 1e-2,
}
sparse_layer = ConvSparseLayer(**sparse_layer_config)
sparse_layer.to(device)

In [None]:
# Optionally restore pretrained sparse-layer weights from a saved checkpoint.
# map_location=device remaps tensors saved on a GPU onto whatever device the
# setup cell selected, so this also works on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(
    "/home/dwh48@drexel.edu/sparse_coding_torch/model-20211027-034737.pt",
    map_location=device,
)
sparse_layer.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Clip paths placed in the "incorrect" group by the analysis loop.
# Presumably the false positives (fp, from PTX_Sliding) and false negatives
# (fn, from PTX_No_Sliding) of a separate classifier run — TODO confirm.
fp_ids = [
    '/shared_data/bamc_data_scale_cropped/PTX_Sliding/image_104548309385533_CLEAN.mov',
    '/shared_data/bamc_data_scale_cropped/PTX_Sliding/image_588413346180_CLEAN.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_Sliding/image_24164968068436_CLEAN.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_Sliding/image_104543812690743_CLEAN.mov',
    '/shared_data/bamc_data_scale_cropped/PTX_Sliding/image_1499268364374_clean.mov',
    '/shared_data/bamc_data_scale_cropped/PTX_Sliding/image_1511338287338_clean.mov',
]
fn_ids = [
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_610066411380_CLEAN.mov',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_642169070951_clean.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_1543571117118_clean.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_6056976176281_CLEAN.mov',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_27185428518326_CLEAN.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_588695055398_clean.mov',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_2418161753608_clean.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_2454526567135_CLEAN.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_584357289931_clean.mov',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_27180764486244_CLEAN.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_1884162273498_clean.mov',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_417221672548_CLEAN.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_426794579576_CLEAN.mp4',
    '/shared_data/bamc_data_scale_cropped/PTX_No_Sliding/image_1895283541879_clean.mov',
]

# Compare sparse-coding statistics between clips listed in fp_ids / fn_ids
# ("incorrect") and every other training clip ("correct").
# A set gives O(1) membership tests; assumes vid_f entries are path strings
# like the ones listed above.
misclassified_ids = set(fp_ids) | set(fn_ids)

num_filters = sparse_layer.out_channels

incorrect_sparsity = []
correct_sparsity = []
incorrect_filter_act = torch.zeros(num_filters)
correct_filter_act = torch.zeros(num_filters)

for _labels, local_batch, vid_f in tqdm(train_loader):
    # Zero initial state shaped [1, num_filters, *output_shape], sized from the
    # current batch (not a separate sample batch) so it always matches.
    u_init = torch.zeros([1, num_filters] + sparse_layer.get_output_shape(local_batch),
                         device=device)

    activations, _ = sparse_layer(local_batch.to(device), u_init)

    # Only the first clip of the batch is analyzed, matching the vid_f[0]
    # lookup below (batch_size is 1 in the setup cell). Assumes activations is
    # laid out like u_init: (batch, num_filters, ...).
    clip_acts = activations[0].detach()

    # Fraction of non-zero activation entries, kept as a plain float so the
    # lists do not hold device tensors.
    sparsity = (torch.count_nonzero(clip_acts) / clip_acts.numel()).item()

    # Total activation per filter: flatten every dim after the filter axis and
    # sum. Unlike squeeze() + a fixed dim list, this does not depend on which
    # output dimensions happen to have size 1.
    filter_act = clip_acts.flatten(start_dim=1).sum(dim=1).cpu()

    # Normalize by the largest filter total; skip an all-zero result, which
    # would otherwise produce NaNs that poison the running totals.
    peak = filter_act.max()
    if peak != 0:
        filter_act = filter_act / peak

    if vid_f[0] in misclassified_ids:
        incorrect_sparsity.append(sparsity)
        incorrect_filter_act += filter_act
    else:
        correct_sparsity.append(sparsity)
        correct_filter_act += filter_act

# Mean sparsity of the correct and incorrect groups (NaN if a group is empty).
print(torch.mean(torch.tensor(correct_sparsity)))
print(torch.mean(torch.tensor(incorrect_sparsity)))
    

In [None]:
# Animate the learned filters, ordered from most to least active on the
# "correct" clips (by accumulated normalized activation).
filters = sparse_layer.filters.detach().cpu()
print(filters.size())

# Indexing with the descending-activation permutation reorders the filters
# along dim 0; equivalent to stacking filters[i] for each i in that order.
filter_order = torch.argsort(correct_filter_act, descending=True)
filters = filters[filter_order]

print(filters.size())

ani = plot_filters(filters)
# To display inline instead of saving: HTML(ani.to_html5_video())
output_path = "/home/dwh48@drexel.edu/sparse_coding_torch/data_classifiers/outputs/kfold_3dcnn/correct_vis.mp4"
ani.save(output_path)