In [None]:
import os
import cv2
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from functools import partial
plt.style.use('default')

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

from utils.dataset import VideoDataset, MyConcatDataset, VideoDatasetRNN
from utils.models import TrackNetV2MSE, TrackNetV2NLL, TrackNetV2RNN
from utils.training import train_model

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any cell visible here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Automatically reload imported modules before each cell runs, so edits to
# the local `utils` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Dataset optimizations

In [None]:
# Keyword arguments forwarded to the dataset constructor below.
# (The meaning of each option is defined by utils.dataset, not by this notebook.)
dataset_params = {
    "image_size": (360, 640),
    "sequence_length": 4,
    "sigma": 5,
    "drop_duplicate_frames": False,
    "transform": ToTensor(),
    "target_transform": ToTensor(),
    "grayscale": False,
}

dataset = VideoDatasetRNN(root="../datasets/prova/", **dataset_params)

In [None]:
from torch.utils.data._utils.collate import default_collate

sequence_length = 4
clear_probability = 0.9

def collate_fn(batch):
    """Collate samples and randomly blank the leading entries of each one.

    The samples are stacked with ``default_collate``.  Then, independently for
    every sample and with probability ``clear_probability``, the first ``k``
    entries along dimension 1 are set to zero, where ``k`` is drawn uniformly
    from ``1 .. sequence_length - 1``.

    Args:
        batch: list of ``(frames, labels)`` pairs as produced by the dataset.

    Returns:
        A tuple ``(x, aux, labels)``: the partially zeroed copy of the stacked
        frames, a zero tensor with one entry per sample, and the stacked labels.
    """
    frames, labels = default_collate(batch)

    # Work on a copy so the collated `frames` tensor itself is left untouched.
    x = frames.clone()

    for i in range(len(batch)):
        if torch.rand(1) < clear_probability:
            # randint's upper bound is exclusive, so at most
            # sequence_length - 1 leading entries are cleared.
            to_delete = int(torch.randint(low=1, high=sequence_length, size=(1,)))
            # Scalar assignment broadcasts over whatever trailing dimensions the
            # batch has, so this works for (B, C, H, W) and (B, T, C, H, W)
            # alike and avoids allocating a temporary zero tensor.
            x[i, :to_delete] = 0
    return x, torch.zeros(len(batch)), labels

# Single-sample, unshuffled loader that applies the custom collate function.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Pull the first batch. `a` is the second value returned by collate_fn.
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells read it.
input, a, labels = next(iter(dataloader))
input = input.to(torch.float32)

In [None]:
# Display index 3 along dimension 1 of the first batch element.
plt.imshow(input[0,3], cmap='gray')

In [None]:
# Build the 4-frame TrackNetV2MSE model and call its `load` with a checkpoint path
# (presumably restoring trained weights — `load` is defined in utils.models).
model = TrackNetV2MSE(sequence_length=4)
model.load('checkpoints/tracknet_v2_mse_360_640_4f/checkpoint_0020_best.ckpt')
# Switch to inference mode; the trailing ';' suppresses the cell's output.
model.eval();

In [None]:
# Rebinds `model` to a freshly constructed TrackNetV2RNN; no checkpoint is loaded here.
# NOTE(review): when the cells run top to bottom this replaces the TrackNetV2MSE
# instance created in the previous cell — confirm which model the cells below should use.
model = TrackNetV2RNN(sequence_length=4)
model.eval();

In [None]:
# Forward pass on the sampled batch with gradient tracking disabled.
with torch.no_grad():
    output = model(input)

In [None]:
# Compare the shape of one input sample with the shape of the whole output batch.
print(input[0].shape)
print(output.shape)

In [None]:
# Flatten entry 0 of the first batch element of the output for inspection.
# An explicit index is used because no earlier visible cell defines `i`,
# so `output[0][i]` raised NameError on a fresh kernel.
output[0][0].ravel()

In [None]:
# First batch element of the model input.
input_hm = input[0]

# NOTE(review): if `input` is (batch, channels, H, W) — TODO confirm — then
# `input_hm[0][-1]` is a single 1-D row rather than a 2-D image; check the
# intended index.
plt.imshow(input_hm[0][-1], cmap='gray')
plt.show()
# Entry 0 of the first batch element of the model output.
plt.imshow(output[0][0], cmap='gray')

In [None]:
# Repeat the forward pass under CPU autocast, still without gradient tracking.
with torch.autocast(device_type='cpu'):
    with torch.no_grad():
        output = model(input)

In [None]:
from torch.optim.lr_scheduler import OneCycleLR, StepLR
from torch.optim import Adam
from utils.models import TrackNetV2MSE

# Rebinds `model` to a default-constructed TrackNetV2MSE, used here only to
# provide parameters for the optimizer.
model = TrackNetV2MSE()
optimizer = Adam(model.parameters())
# One-cycle schedule sized for 10 epochs of 100 steps each.
scheduler = OneCycleLR(optimizer=optimizer, max_lr=1e-2, epochs=10, steps_per_epoch=100)

# Not the last expression of the cell, so this value is not displayed.
scheduler.total_steps

# TODO: Look for the scheduler for probabilities
class MyScheduler(OneCycleLR):
    """OneCycleLR that also exposes a probability derived from schedule progress.

    Construction is inherited unchanged from ``OneCycleLR``.
    """

    def get_probability(self):
        """Return the fraction of the schedule completed so far, clamped to [0, 1].

        This is a placeholder until the TODO above is resolved: it reads this
        scheduler's own ``last_epoch`` and ``total_steps`` (not a notebook
        global), so it can be called on any instance.
        """
        progress = self.last_epoch / self.total_steps
        return min(max(progress, 0.0), 1.0)

# Instantiate the subclass with the same settings as `scheduler` and show its step counter.
s = MyScheduler(optimizer=optimizer, max_lr=1e-2, epochs=10, steps_per_epoch=100)
s.last_epoch

Trying to find the bottleneck

In [None]:
# Wall-clock time for 1000 calls to `_generate_heatmap` with random indices in [749, 780).
t = time.time()
for i in range(1000):
    dataset._generate_heatmap(np.random.randint(749, 780))
print(time.time()-t)

In [None]:
# Wall-clock time for 1000 calls to `_get_heatmap` on `dataset`, random indices in [749, 780).
t = time.time()
for i in range(1000):
    dataset._get_heatmap(np.random.randint(749, 780))
print(time.time()-t)

# Same measurement on `dataset_2`.
# NOTE(review): `dataset_2` is not defined in any cell visible here; this
# fails on a fresh kernel unless it is created elsewhere.
t = time.time()
for i in range(1000):
    dataset_2._get_heatmap(np.random.randint(749, 780))
print(time.time()-t)

The bottleneck is in the `ToTensor()` transform

In [None]:
# Wall-clock time for 100 random item reads (indices in [0, 100)) from `dataset`.
t = time.time()
for i in range(100):
    dataset[np.random.randint(0, 100)]
print(time.time()-t)

# Same for `dataset_2`, but with only 10 reads, so the two printed totals
# are not directly comparable.
# NOTE(review): `dataset_2` is not defined in any cell visible here.
t = time.time()
for i in range(10):
    dataset_2[np.random.randint(0, 100)]
print(time.time()-t)

In [None]:
# Repeat of the item-read timing: 100 random reads (indices in [0, 100)) from `dataset`.
t = time.time()
for i in range(100):
    dataset[np.random.randint(0, 100)]
print(time.time()-t)

# Only 10 reads from `dataset_2`, so the two printed totals are not directly comparable.
# NOTE(review): `dataset_2` is not defined in any cell visible here.
t = time.time()
for i in range(10):
    dataset_2[np.random.randint(0, 100)]
print(time.time()-t)

In [None]:
# Another repeat of the item-read timing: 100 random reads (indices in [0, 100)) from `dataset`.
t = time.time()
for i in range(100):
    dataset[np.random.randint(0, 100)]
print(time.time()-t)

# Only 10 reads from `dataset_2`, so the two printed totals are not directly comparable.
# NOTE(review): `dataset_2` is not defined in any cell visible here.
t = time.time()
for i in range(10):
    dataset_2[np.random.randint(0, 100)]
print(time.time()-t)

In [None]:
# Take entry 0 of the frames from item 0 and from item 10, for the comparison timing below.
frames, labels = dataset[0]
frame1 = frames[0]

frames, labels = dataset[10]
frame2 = frames[0]

In [None]:
# Wall-clock time for 1000 calls to `_equal_frames` on the same pair of frames.
t = time.time()
for i in range(1000):
    dataset._equal_frames(frame1, frame2)
print(time.time()-t)

# Visualize some activations and kernels because why not

In [None]:
# Rebinds `model` to the 3-frame TrackNetV2MSE variant and calls its `load`
# with a checkpoint path (presumably restoring trained weights — defined in utils.models).
model = TrackNetV2MSE(sequence_length=3)
model.load('checkpoints/tracknet_v2_mse_360_640/checkpoint_0027_best.ckpt')
model.eval()
# Last expression of the cell: displays the module structure.
model

In [None]:
# Keyword arguments for the 3-frame dataset used by the visualisation cells.
# (The meaning of each option is defined by utils.dataset, not by this notebook.)
dataset_params = {
    "image_size": (360, 640),
    "sequence_length": 3,
    "sigma": 5,
    "drop_duplicate_frames": False,
    "heatmap_mode": 'image',
    "transform": ToTensor(),
    "target_transform": ToTensor(),
    "grayscale": False,
}

# Rebinds `dataset`, replacing the VideoDatasetRNN instance created earlier.
dataset = VideoDataset(root="../datasets/prova/", **dataset_params)

In [None]:
# NOTE(review): `counter` is not referenced by any cell visible here.
counter = 0

def get_encoding_layer(desired_block=1, subblock=0, net=None):
    """Collect, in order, the layers leading up to a ReLU inside one block.

    The top-level children of the network are visited in order.  A child at
    an odd position is appended as a whole.  For every child, each of its
    children (the "sub-blocks") is visited and that sub-block's own children
    are appended one by one.  Collection stops right after a ``torch.nn.ReLU``
    is appended while visiting top-level position ``2 * desired_block`` and
    sub-block ``subblock``.

    Args:
        desired_block: index of the target block; it is looked up at top-level
            position ``2 * desired_block``.
        subblock: index of the sub-block inside the target block.
        net: module to walk.  Defaults to the notebook-level ``model`` so
            existing two-argument calls keep working.

    Returns:
        list of modules that can be applied sequentially to reproduce the
        activation at the stopping point.
    """
    if net is None:
        net = model  # fall back to the notebook-level model

    target = 2 * desired_block
    layers = []
    # Most recently appended leaf layer.  Pre-initialised so the stop checks
    # below are well defined even when a block has no grandchildren.
    layer = None

    for i, block in enumerate(net.children()):
        if i % 2 == 1:
            layers.append(block)
        in_target = i == target
        for j, block_element in enumerate(block.children()):
            for layer in block_element.children():
                layers.append(layer)
                if in_target and j == subblock and type(layer) is torch.nn.ReLU:
                    return layers
            # Reached only when the sub-block ended without hitting the stop
            # above (e.g. it has no children); `layer` may be carried over.
            if in_target and j == subblock and type(layer) is torch.nn.ReLU:
                return layers
        # Requested sub-block not found: stop at the end of the target block
        # if the last collected layer is a ReLU, otherwise keep walking.
        if in_target and type(layer) is torch.nn.ReLU:
            return layers
    return layers

def compute_activations(layers, input):
    """Apply ``layers`` in order to one sample and return the result as a NumPy array.

    Args:
        layers: iterable of callables (e.g. modules) applied sequentially.
        input: a single un-batched tensor; a leading batch dimension of size 1
            is added before the first layer.

    Returns:
        numpy.ndarray holding the final activation with the batch dimension removed.
    """
    activation = input.unsqueeze(dim=0)
    with torch.no_grad():
        for layer in layers:
            activation = layer(activation)

    # Remove only the batch dimension added above: a bare squeeze() would also
    # collapse a singleton channel or spatial dimension and change the rank.
    # detach()/cpu() make the NumPy conversion valid for any input tensor.
    return activation.squeeze(dim=0).detach().cpu().numpy()

In [None]:
# Item 50 is the sample used by the visualisation cells below.
frames, labels = dataset[50]
frames = frames.to(torch.float32)

In [None]:
# Figure size in pixels (two 16:9 panels, 300 px high), converted to inches via dpi below.
w, h, dpi = 300*2*16/9, 300, 100

fig, axs = plt.subplots(ncols=2, figsize=(w/dpi, h/dpi), dpi=dpi)

# Last three entries along dim 0 of `frames`, moved to channels-last for imshow
# (presumably the RGB channels of the last frame — TODO confirm channel layout).
axs[0].imshow(frames[-3:].numpy().transpose(1, 2, 0))
axs[0].set_title("Input frame (last in sequence)")

axs[1].imshow(labels[0])
axs[1].set_title("Ground truth")

fig.tight_layout(pad=0.2)
plt.show()

In [None]:
# Sweep the noise fraction over 10 evenly spaced values from 0 to 1 and record
# the maximum of the model output for each.
noise_part = np.linspace(0, 1, 10)
c = []

for n in noise_part:
    with torch.no_grad():
        # Linear blend of the sample with standard-normal noise (unseeded, so
        # the curve differs between runs).
        f = (1-n)*frames + n*torch.randn(frames.shape)
        out = model(f.unsqueeze(dim=0)).squeeze().numpy()
    c.append(out.max())
plt.plot(noise_part, c)

In [None]:
# Model output for a single noise fraction, shown as an image with a colour bar.
n = 0.07
with torch.no_grad():
    f = (1-n)*frames + n*torch.randn(frames.shape)
    out = model(f.unsqueeze(dim=0)).squeeze().numpy()
plt.imshow(out)
plt.colorbar()
plt.show()

In [None]:
# Block / sub-block whose activations are inspected; also used in the saved file name below.
block = 2
subblock = 2

activations = compute_activations(get_encoding_layer(block, subblock), frames)
activations.shape

In [None]:
# An activation map counts as "dead" when its maximum over axes 1 and 2 is exactly 0.
(dead_activations, ) = np.where(activations.max(axis=(1,2))==0)
print(f"Of {activations.shape[0]} activations, {dead_activations.size} are dead and {activations.shape[0]-dead_activations.size} are not.")

In [None]:
# Figure size in pixels (16:9 at the given height), converted to inches via dpi below.
height_pixels = 1080
top_adjust = 1

w, h, dpi = height_pixels*16/9*top_adjust, height_pixels, 100
fig, axs = plt.subplots(nrows=8, ncols=8, figsize=(w/dpi, h/dpi), dpi=dpi)

# Offset of the first activation map shown in the 8x8 grid.
i_0 = 0

# Needs at least i_0 + 64 activation maps; fewer raises IndexError.
for i, ax in enumerate(axs.ravel()):
    ax.imshow(activations[i+i_0], cmap='gray')
    # ax.set_title(i)
    ax.set_axis_off()

#fig.suptitle(f"Activations in encoding block {block}, subblock {subblock}")

fig.tight_layout(pad=0.5)
fig.subplots_adjust(top=top_adjust)

# Written to the current working directory, named after the block/sub-block indices.
fig.savefig(f"{block}_{subblock}.png")

plt.show()

In [None]:
# List the parameter/buffer names, e.g. to find the keys used in the next cell.
model.state_dict().keys()

In [None]:
# Index (along dim 0 of the weight tensor) of the kernel to visualise.
k = 4

kernels = model.state_dict()['vgg_conv1.1.0.weight'].numpy()
biases = model.state_dict()['vgg_conv1.1.0.bias'].numpy()
w, h, dpi = 800, 800, 100
fig, axs = plt.subplots(nrows=8, ncols=8, figsize=(w/dpi, h/dpi), dpi=dpi)

print(kernels.shape)
print(biases[k])

min_val = kernels[k].min()
max_val = kernels[k].max()
print(min_val, max_val)

# Make the colour range symmetric around zero so the diverging colormap is
# centred: max_val becomes the largest magnitude, min_val its negative.
max_val=max((max_val, -min_val))
min_val=min((-max_val, min_val))

# One panel per index along dim 1 of the weights; needs at least 64 of them.
for i, ax in enumerate(axs.ravel()):
    ax.imshow(kernels[k,i], cmap='RdBu', vmin=min_val, vmax=max_val)
    ax.set_axis_off()

#fig.suptitle(f"Kernel {k}, bias = {biases[k]:.2g}")
fig.tight_layout(pad=0.2)
plt.show()
