In [None]:
from lib.video_dataset import VideoFrameDataset
from config.dataset import get_dataset_path


In [None]:
import sys

# Make the parent directory importable so the `common` package used below
# resolves. Guarded so that re-running this cell does not append the same
# entry to sys.path again.
if "../" not in sys.path:
    sys.path.append("../")

from common.config.torch_config import get_transform, unnormalize, device
from common.utils.output import plot_tensor, plot_train_val_data


## Hiperparámetros

In [None]:
from config.const import *

In [None]:
# Transform built for the configured image size; handed to the dataset below.
multiple_transform = get_transform(IMAGE_SIZE)

# Dataset location and the path the trained model is written to later on.
data_path, model_path = get_dataset_path(model_name="WLASL_tanh_8", dataset="WLASL/videos")


In [None]:
# Video dataset rooted at data_path, using the sampling constants from config.
dataset = VideoFrameDataset(
    root_path=data_path,
    num_segments=NUM_SEGMENTS,
    frames_per_segment=FRAMES_PER_SEGMENT,
    image_size=IMAGE_SIZE,
    transform=multiple_transform,
)

# Class labels exposed by the dataset; indexed later to print a sample's label.
classes = dataset.classes
print(classes)

In [None]:
from utils.loader import split_dataset

In [None]:
# NOTE(review): the unpacking order here is train, test, validation — confirm
# it matches the order in which split_dataset returns its loaders.
train_loader, test_loader, validation_loader = split_dataset(
    dataset,
    batch_size=BATCH_SIZE,
    train_split=0.70,
    validation_split=0.1,
)

In [None]:
# Length of each loader, printed in the order: train, validation, test.
print(*map(len, (train_loader, validation_loader, test_loader)))

## Tensorboard logger y writer

In [None]:
from torch.utils.tensorboard.writer import SummaryWriter

In [None]:
# TensorBoard writer shared by the logging cells and the training call.
writer = SummaryWriter(log_dir="./tensorboard/logs")

## Ejemplo de entrada de la red

> Initial input = [BATCH_SIZE, NUMBER_OF_FRAMES, CHANNELS, HEIGHT, WIDTH]

In [None]:
from torchvision.utils import make_grid
from torch import nn

In [None]:
def plot_grid(
    tensor,
    dims=(1, 2, 0),
    nrow=FRAMES_PER_SEGMENT * NUM_SEGMENTS,
    unnorm=True,
    start_dim=0,
    end_dim=1,
):
    """Tile a stack of images into one grid, plot it and return the grid.

    Args:
        tensor: Image tensor to display. Dimensions ``start_dim..end_dim`` are
            merged first so that, e.g., a batch of videos becomes one flat
            sequence of frames.
        dims: Forwarded unchanged to ``plot_tensor`` (presumably the axis
            permutation applied before display).
        nrow: Passed to ``make_grid`` as ``nrow``.
        unnorm: When true, the grid goes through ``unnormalize`` before
            being plotted.
        start_dim: First dimension of the range to merge.
        end_dim: Last dimension of the range to merge; with
            ``end_dim == start_dim`` the shape is left untouched.

    Returns:
        The grid tensor that was plotted (unnormalized when ``unnorm`` is true).
    """
    # Detach first: plotting needs no autograd history, and this lets the
    # helper accept layer outputs that still require grad.
    flattened = tensor.detach().flatten(start_dim=start_dim, end_dim=end_dim)
    grid = make_grid(flattened.cpu(), nrow=nrow)

    if unnorm:
        grid = unnormalize(grid)

    plot_tensor(grid, dims)
    return grid

### Visualización de batch completo

In [None]:
# Pull one batch from the training loader: frames plus (class, pose) targets.
first_batch, (ground_classes, ground_poses) = next(iter(train_loader))
grid = plot_grid(first_batch)

# One shape per line: input batch, class targets, pose targets.
print(first_batch.shape, ground_classes.shape, ground_poses.shape, sep="\n")

### Hacemos log del grid en tensorboard

In [None]:
# Log the full-batch grid as an image in TensorBoard.
writer.add_image(tag=f"Example of full batch with {BATCH_SIZE} videos", img_tensor=grid)

In [None]:
# Which video of the batch, and which frame of that video, the cells below inspect.
n_video, frame_of_video = 0, 5

### Muestra el tensor con un video completo

In [None]:
# Selected video from the batch, plotted with all of its frames on one grid row.
video = first_batch[n_video]
_ = plot_grid(video, end_dim=0, nrow=video.shape[0])

### Muestra un único frame del video

In [None]:
# Single frame picked from the selected video.
img = video[frame_of_video]

In [None]:
# Ground-truth class entry of the selected video; used to index `classes`.
target = ground_classes[n_video]

In [None]:
# Print the label of the selected video, then plot its frame and log it.
print(classes[target])
grid = plot_grid(img, end_dim=0)
writer.add_image(tag="Example of image", img_tensor=grid)

### Muestra la pose de salida para ese frame

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Pose targets for the same video and frame selected above.
video_poses = ground_poses[n_video]
img_pose = video_poses[frame_of_video]
# Product of the first two dimensions of one frame's pose entry.
# NOTE(review): presumably (points, coordinates) — confirm against the dataset.
POSES_PER_FRAME = img_pose.size(0) * img_pose.size(1)
print(img_pose.shape)

In [None]:
# Scale the pose values by the image size, then transpose so the plotting
# cell can read the coordinates row by row.
# NOTE(review): despite its name, `norm_img_pose` holds the scaled values;
# presumably the stored poses are normalized to [0, 1] — confirm.
norm_img_pose = IMAGE_SIZE * img_pose
img_pose_transpose = norm_img_pose.T

In [None]:
# The first two rows of the transposed pose are read as x and y coordinates.
x, y = img_pose_transpose[:2]

plt.scatter(x, y)
# Invert the y axis of the current axes (image-style orientation).
plt.gca().invert_yaxis()

### Mostramos la pose superpuesta sobre el frame

In [None]:
# Pass the frame through `unnormalize` before display (presumably undoing the
# input normalization applied by the transform).
unorm_img = unnormalize(img)

# Draw the pose points and the frame on the same (current) axes.
plt.scatter(x, y)
plt.imshow(unorm_img.permute(1, 2, 0))  # reorder axes for imshow (presumably CHW -> HWC)

fig = plt.gcf()
# NOTE(review): add_figure closes the figure by default (close=True), which is
# presumably why the figure object is returned explicitly as the cell result.
writer.add_figure("Example of image with pose", fig)

# fig.show()
fig

### Pequeña prueba de una convolución

Hacemos esto para ver la salida tras aplicar filtros

In [None]:
hidden_1, hidden_2 = 16, 32

# The frame axis of the batch is fed to Conv3d as its channel dimension, so
# in_channels is the number of frames per video.
conv1 = nn.Conv3d(
    FRAMES_PER_SEGMENT * NUM_SEGMENTS,
    hidden_1,
    kernel_size=(2, 3, 3),
    stride=2,
    padding=1,
)
conv2 = nn.Conv3d(hidden_1, hidden_2, kernel_size=3, stride=2, padding=1)
relu = nn.LeakyReLU()
batch_1 = nn.BatchNorm3d(hidden_1)
batch_2 = nn.BatchNorm3d(hidden_2)

# Use a dedicated name instead of `x`: `x` already holds the pose
# x-coordinates plotted earlier, and overwriting it here would break those
# cells when they are re-run.
conv_out = conv1(first_batch)
conv_out = relu(conv_out)
conv_out = batch_1(conv_out)

print(conv_out.shape)

conv_out = conv2(conv_out)
conv_out = relu(conv_out)
conv_out = batch_2(conv_out)

print(conv_out.shape)


# One grid row per sample: hidden_2 feature maps per row.
_ = plot_grid(conv_out, nrow=hidden_2)

## Comprobamos el estado de balanceo de los loaders del dataset

In [None]:
# from nets.common.utils.balance import check_balance_status

# print(check_balance_status(test_loader, classes))
# print(check_balance_status(validation_loader, classes))
# print(check_balance_status(train_loader, classes))

In [None]:
from lib.model import CNN

# Total number of frames sampled per video.
num_frames = NUM_SEGMENTS * FRAMES_PER_SEGMENT

# The pose head size covers every frame: per-frame pose values times frames.
model = CNN(
    num_frames=num_frames,
    num_pose_points=num_frames * POSES_PER_FRAME,
    image_size=IMAGE_SIZE,
    num_classes=len(classes),
)

In [None]:
from torchinfo import summary

# Show torchinfo's summary of the model as the cell output.
summary(model)

In [None]:
# Log the model graph, using the example batch as the input.
writer.add_graph(model, first_batch)
# Flush instead of close: the same writer is handed to train_model afterwards,
# so keep it open while still getting the graph onto disk now.
writer.flush()

### Mostramos un grafo del modelo con tensorboard

<img src="https://i.imgur.com/cvkNqyB.png" alt="Grafo modelo con tensorboard" width="400"/>

In [None]:
from lib.train import train_model

## Entrenamos la red

In [None]:
# Train the model; the per-epoch cost and accuracy histories are plotted below.
train_costs, val_costs, train_accs, val_accs = train_model(
    model,
    train_loader,
    validation_loader,
    device,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    writer=writer,
)

### Plot de pérdida y accuracy

In [None]:
# One train-vs-validation plot per tracked metric.
for train_series, val_series, label in (
    (train_costs, val_costs, "Costs"),
    (train_accs, val_accs, "Accuracy"),
):
    plot_train_val_data(train_series, val_series, ylabel=label)

In [None]:
import sys

# "../" is already appended near the top of the notebook; guard the append so
# running this cell does not add a duplicate sys.path entry.
if "../" not in sys.path:
    sys.path.append("../")

from common.utils.check_accuracy import check_accuracy

## Comprobamos el accuracy de la red en los tres sets

In [None]:
# Accuracy on the training loader.
# NOTE(review): n_batchs=10 presumably caps how many batches are evaluated — confirm.
check_accuracy(train_loader, model, classes, device, has_pose=True, n_batchs=10)

In [None]:
# Accuracy on the validation loader.
check_accuracy(validation_loader, model, classes, device, has_pose=True)

In [None]:
# Accuracy on the test loader.
check_accuracy(test_loader, model, classes, device, has_pose=True)


## Exportamos modelo

In [None]:
from torch import save, onnx, randn

In [None]:
# Persist the whole model object (not just a state_dict) to model_path.
save(obj=model, f=model_path)
print("Model exported to {}".format(model_path))

### Exportamos modelo en formato estándar ONNX

In [None]:
import os

# Example input laid out as [batch, frames, channels, height, width].
dummy_input = randn(
    BATCH_SIZE,
    FRAMES_PER_SEGMENT * NUM_SEGMENTS,
    3,
    IMAGE_SIZE,
    IMAGE_SIZE,
    device=device,
)

# Derive the ONNX path by swapping the file extension. A plain
# str.replace(".pth", ".onnx") returns the path unchanged when it contains no
# ".pth", which would make the export overwrite the checkpoint saved above.
onnx_path = os.path.splitext(model_path)[0] + ".onnx"

onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=["input"],
    # Only the batch dimension of the input is exported as dynamic.
    dynamic_axes={"input": {0: "batch_size"}},
)

### Mostramos el grafo de ONNX con [_Netron_](https://netron.app/)

<img src="https://i.imgur.com/jDkeBMz.png" alt="Grafo del modelo exportado con netrón" width="400"/>
