In [1]:
import time
import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.animation import FuncAnimation
from conv_sparse_model import ConvSparseLayer

from train_conv3d_sparse_model import load_balls_data
from train_conv3d_sparse_model import plot_original_vs_recon
from train_conv3d_sparse_model import plot_filters
from train_conv3d_sparse_model import plot_video

from video_loader import VideoLoader
from video_loader import MinMaxScaler

from IPython.display import HTML

In [2]:
def load_balls_data(batch_size, path='ball_videos.npy'):
    """Build a shuffling DataLoader over the pre-rendered ball videos.

    NOTE(review): this definition shadows the ``load_balls_data`` imported from
    ``train_conv3d_sparse_model`` in the first cell; keep only one of them.

    Args:
        batch_size: Number of videos per batch.
        path: Location of the ``.npy`` array of videos. Defaults to
            ``'ball_videos.npy'`` relative to the current working directory.

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles along the first axis of
        the saved array and yields float32 tensors. The layout of the remaining
        axes is whatever the file stores; it is not checked here.
    """
    with open(path, 'rb') as fin:
        # .float() gives float32 regardless of the dtype stored on disk.
        ball_videos = torch.tensor(np.load(fin)).float()

    return torch.utils.data.DataLoader(ball_videos,
                                       batch_size=batch_size,
                                       shuffle=True)

In [2]:
def load_bamc_data(batch_size,
                   video_path="/home/cm3786@drexel.edu/bamc_data/",
                   num_frames=60,
                   frame_size=(160, 90)):
    """Build a shuffling DataLoader over the BAMC video clips.

    The transform pipeline (grayscale with one output channel, resize,
    ``MinMaxScaler(0, 255)``) is handed to ``VideoLoader``. How and when it is
    applied is up to ``VideoLoader``.

    Args:
        batch_size: Number of clips per batch.
        video_path: Directory passed to ``VideoLoader``. The default is the
            original author's absolute home-directory path. Pass your own data
            location rather than editing this function.
        num_frames: Frames per clip, forwarded to ``VideoLoader``.
        frame_size: ``size`` argument for ``torchvision.transforms.Resize``,
            which interprets a pair as ``(height, width)``.
            NOTE(review): the notebook's sample batch has trailing dims
            (90, 160), the reverse of the default (160, 90). Confirm whether
            VideoLoader reorders axes or whether this pair is swapped.

    Returns:
        A ``torch.utils.data.DataLoader`` over the ``VideoLoader`` dataset.
    """
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.Resize(size=frame_size),
        MinMaxScaler(0, 255),
    ])
    dataset = VideoLoader(video_path, transform=transforms, num_frames=num_frames)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [8]:
# GPUs used for training. DataParallel needs the wrapped module on
# device_ids[0], so the primary device is derived from the first id.
# Edit this list to match the GPUs available on your machine.
GPU_IDS = [1, 2, 3]
PER_GPU_BATCH_SIZE = 4

device = torch.device("cuda:{}".format(GPU_IDS[0])
                      if torch.cuda.is_available() else "cpu")
# Compare the device *type*: a torch.device never compares equal to a plain
# string, so the old `device == "cpu"` test was always False and CPU runs
# silently used the multi-GPU batch size.
if device.type == "cpu":
    batch_size = 1
else:
    batch_size = PER_GPU_BATCH_SIZE * len(GPU_IDS)

# Alternative dataset: train_loader = load_balls_data(batch_size)
train_loader = load_bamc_data(batch_size)

# One batch kept for the visualisation cells further down.
example_data = next(iter(train_loader))

sparse_layer = ConvSparseLayer(in_channels=1,
                               out_channels=16,
                               kernel_size=5,
                               stride=1,
                               padding=0,
                               convo_dim=3,
                               rectifier=True,
                               shrink=0.25,
                               lam=0.5,
                               max_activation_iter=200,
                               activation_lr=1e-2)
model = torch.nn.DataParallel(sparse_layer, device_ids=GPU_IDS)
model.to(device)

learning_rate = 1e-2
# DataParallel wraps sparse_layer itself, so optimising its parameters
# updates the same tensors the parallel model uses.
filter_optimizer = torch.optim.Adam(sparse_layer.parameters(),
                                    lr=learning_rate)

In [9]:
# Shape of the video tensor in the sample batch (index 1 of the (labels, videos) pair).
example_data[1].size()

torch.Size([12, 1, 60, 90, 160])

In [None]:
# Preview the first clip of the sample batch. The loader yields
# (labels, videos) pairs, so index 1 selects the video tensor.
ani = plot_video(example_data[1].select(0, 0))
HTML(ani.to_html5_video())

In [None]:
NUM_EPOCHS = 30

for epoch in range(NUM_EPOCHS):
    # The loader yields (labels, videos); this loop does not use the labels.
    for _labels, local_batch in train_loader:
        local_batch = local_batch.to(device)

        t1 = time.perf_counter()
        activations = model(local_batch)
        # CUDA work is launched asynchronously; wait on every DataParallel
        # device so the timing covers the computation, not just the launch.
        # device_ids is empty when running on CPU, making this a no-op there.
        for gpu in model.device_ids:
            torch.cuda.synchronize(gpu)
        t2 = time.perf_counter()
        print('activations took {} sec'.format(t2-t1))

        loss = sparse_layer.loss(local_batch, activations)
        print('epoch={}, loss={}'.format(epoch, loss))
        print()

        # Gradient step on the filters, followed by the layer's own
        # normalize_weights() (see ConvSparseLayer for what it enforces).
        filter_optimizer.zero_grad()
        loss.backward()
        filter_optimizer.step()
        sparse_layer.normalize_weights()

activations took 27.950331718020607 sec
epoch=0, loss=1549.7274169921875

activations took 23.936351343989372 sec
epoch=0, loss=1460.863525390625



In [None]:
# Reconstruct the first clip of the sample batch with the current filters
# and animate it next to the original frames.
activations = sparse_layer(example_data[1][:1].to(device))
reconstructions = (sparse_layer.reconstructions(activations)
                   .detach()
                   .cpu()
                   .numpy())

ani = plot_original_vs_recon(example_data[1], reconstructions, idx=0)
HTML(ani.to_html5_video())

In [None]:
# Animate the learned filters, detached from autograd and copied to host memory.
ani = plot_filters(sparse_layer.filters.detach().cpu())
HTML(ani.to_html5_video())