In [1]:
import pytorch_lightning as pl
from torchvision.models.video import r3d_18, MViT, R3D_18_Weights, s3d, S3D_Weights
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
import wandb
from pytorch_lightning.loggers import WandbLogger
import torchvision.transforms as T
import yaml
from glob import glob
import os
from dataset import ApasDataset
import pytorchvideo.transforms as Tvid
from torchmetrics.functional import accuracy, auroc, f1_score
from pytorch_lightning.callbacks import ModelCheckpoint
from model import Model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
fold = 0
num_workers = 12
epochs = 30
inputs = ["side", "top", "kinematics"]
batch_size = 8


mean=[0.43216, 0.394666, 0.37645]
std=[0.22803, 0.22145, 0.216989]
transform = T.Compose([
    T.Normalize(mean=mean, std=std),
])

In [3]:
train_ds = ApasDataset(fold, "valid", inputs)

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
trainer = pl.Trainer(devices=1,
                     accelerator="gpu",
                     max_epochs=epochs,
                     num_sanity_val_steps=0,
                     limit_train_batches=0.2,
                    #  limit_val_batches=0.1,
                    # default_root_dir=save_dir,
                    log_every_n_steps=5000//batch_size,
                     )

Kinematics LOADED
Videos LOADED


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Labels LOADED


In [3]:
model = Model(inputs=["kinematics"],early_fusion=False, transform=transform)
model

Model(
  (k_backbone): Sequential(
    (0): Conv1d(36, 72, kernel_size=(3,), stride=(2,), padding=(1,))
    (1): ReLU()
    (2): Conv1d(72, 144, kernel_size=(3,), stride=(2,), padding=(1,))
    (3): ReLU()
    (4): Conv1d(144, 288, kernel_size=(3,), stride=(2,), padding=(1,))
    (5): ReLU()
    (6): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))
    (7): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=288, out_features=6, bias=True)
  )
  (loss): CrossEntropyLoss()
  (f1): MulticlassF1Score()
)

In [5]:
ckpt_path = "/home/student/project/not_used/fold0_['kinematics']_S3D_early_fuse/gesture_recognition/6mttlmpn/checkpoints/epoch=19-step=13460.ckpt"
model.load_from_checkpoint(ckpt_path)

RuntimeError: Error(s) in loading state_dict for Model:
	Missing key(s) in state_dict: "i_backbone.features.0.0.0.weight", "i_backbone.features.0.0.1.weight", "i_backbone.features.0.0.1.bias", "i_backbone.features.0.0.1.running_mean", "i_backbone.features.0.0.1.running_var", "i_backbone.features.0.1.0.weight", "i_backbone.features.0.1.1.weight", "i_backbone.features.0.1.1.bias", "i_backbone.features.0.1.1.running_mean", "i_backbone.features.0.1.1.running_var", "i_backbone.features.2.0.weight", "i_backbone.features.2.1.weight", "i_backbone.features.2.1.bias", "i_backbone.features.2.1.running_mean", "i_backbone.features.2.1.running_var", "i_backbone.features.3.0.0.weight", "i_backbone.features.3.0.1.weight", "i_backbone.features.3.0.1.bias", "i_backbone.features.3.0.1.running_mean", "i_backbone.features.3.0.1.running_var", "i_backbone.features.3.1.0.weight", "i_backbone.features.3.1.1.weight", "i_backbone.features.3.1.1.bias", "i_backbone.features.3.1.1.running_mean", "i_backbone.features.3.1.1.running_var", "i_backbone.features.5.branch0.0.weight", "i_backbone.features.5.branch0.1.weight", "i_backbone.features.5.branch0.1.bias", "i_backbone.features.5.branch0.1.running_mean", "i_backbone.features.5.branch0.1.running_var", "i_backbone.features.5.branch1.0.0.weight", "i_backbone.features.5.branch1.0.1.weight", "i_backbone.features.5.branch1.0.1.bias", "i_backbone.features.5.branch1.0.1.running_mean", "i_backbone.features.5.branch1.0.1.running_var", "i_backbone.features.5.branch1.1.0.0.weight", "i_backbone.features.5.branch1.1.0.1.weight", "i_backbone.features.5.branch1.1.0.1.bias", "i_backbone.features.5.branch1.1.0.1.running_mean", "i_backbone.features.5.branch1.1.0.1.running_var", "i_backbone.features.5.branch1.1.1.0.weight", "i_backbone.features.5.branch1.1.1.1.weight", "i_backbone.features.5.branch1.1.1.1.bias", "i_backbone.features.5.branch1.1.1.1.running_mean", "i_backbone.features.5.branch1.1.1.1.running_var", "i_backbone.features.5.branch2.0.0.weight", "i_backbone.features.5.branch2.0.1.weight", "i_backbone.features.5.branch2.0.1.bias", "i_backbone.features.5.branch2.0.1.running_mean", "i_backbone.features.5.branch2.0.1.running_var", "i_backbone.features.5.branch2.1.0.0.weight", "i_backbone.features.5.branch2.1.0.1.weight", "i_backbone.features.5.branch2.1.0.1.bias", "i_backbone.features.5.branch2.1.0.1.running_mean", "i_backbone.features.5.branch2.1.0.1.running_var", "i_backbone.features.5.branch2.1.1.0.weight", "i_backbone.features.5.branch2.1.1.1.weight", "i_backbone.features.5.branch2.1.1.1.bias", "i_backbone.features.5.branch2.1.1.1.running_mean", "i_backbone.features.5.branch2.1.1.1.running_var", "i_backbone.features.5.branch3.1.0.weight", "i_backbone.features.5.branch3.1.1.weight", "i_backbone.features.5.branch3.1.1.bias", "i_backbone.features.5.branch3.1.1.running_mean", "i_backbone.features.5.branch3.1.1.running_var", "i_backbone.features.6.branch0.0.weight", "i_backbone.features.6.branch0.1.weight", "i_backbone.features.6.branch0.1.bias", "i_backbone.features.6.branch0.1.running_mean", "i_backbone.features.6.branch0.1.running_var", "i_backbone.features.6.branch1.0.0.weight", "i_backbone.features.6.branch1.0.1.weight", "i_backbone.features.6.branch1.0.1.bias", "i_backbone.features.6.branch1.0.1.running_mean", "i_backbone.features.6.branch1.0.1.running_var", "i_backbone.features.6.branch1.1.0.0.weight", "i_backbone.features.6.branch1.1.0.1.weight", "i_backbone.features.6.branch1.1.0.1.bias", "i_backbone.features.6.branch1.1.0.1.running_mean", "i_backbone.features.6.branch1.1.0.1.running_var", "i_backbone.features.6.branch1.1.1.0.weight", "i_backbone.features.6.branch1.1.1.1.weight", "i_backbone.features.6.branch1.1.1.1.bias", "i_backbone.features.6.branch1.1.1.1.running_mean", "i_backbone.features.6.branch1.1.1.1.running_var", "i_backbone.features.6.branch2.0.0.weight", "i_backbone.features.6.branch2.0.1.weight", "i_backbone.features.6.branch2.0.1.bias", "i_backbone.features.6.branch2.0.1.running_mean", "i_backbone.features.6.branch2.0.1.running_var", "i_backbone.features.6.branch2.1.0.0.weight", "i_backbone.features.6.branch2.1.0.1.weight", "i_backbone.features.6.branch2.1.0.1.bias", "i_backbone.features.6.branch2.1.0.1.running_mean", "i_backbone.features.6.branch2.1.0.1.running_var", "i_backbone.features.6.branch2.1.1.0.weight", "i_backbone.features.6.branch2.1.1.1.weight", "i_backbone.features.6.branch2.1.1.1.bias", "i_backbone.features.6.branch2.1.1.1.running_mean", "i_backbone.features.6.branch2.1.1.1.running_var", "i_backbone.features.6.branch3.1.0.weight", "i_backbone.features.6.branch3.1.1.weight", "i_backbone.features.6.branch3.1.1.bias", "i_backbone.features.6.branch3.1.1.running_mean", "i_backbone.features.6.branch3.1.1.running_var", "i_backbone.features.8.branch0.0.weight", "i_backbone.features.8.branch0.1.weight", "i_backbone.features.8.branch0.1.bias", "i_backbone.features.8.branch0.1.running_mean", "i_backbone.features.8.branch0.1.running_var", "i_backbone.features.8.branch1.0.0.weight", "i_backbone.features.8.branch1.0.1.weight", "i_backbone.features.8.branch1.0.1.bias", "i_backbone.features.8.branch1.0.1.running_mean", "i_backbone.features.8.branch1.0.1.running_var", "i_backbone.features.8.branch1.1.0.0.weight", "i_backbone.features.8.branch1.1.0.1.weight", "i_backbone.features.8.branch1.1.0.1.bias", "i_backbone.features.8.branch1.1.0.1.running_mean", "i_backbone.features.8.branch1.1.0.1.running_var", "i_backbone.features.8.branch1.1.1.0.weight", "i_backbone.features.8.branch1.1.1.1.weight", "i_backbone.features.8.branch1.1.1.1.bias", "i_backbone.features.8.branch1.1.1.1.running_mean", "i_backbone.features.8.branch1.1.1.1.running_var", "i_backbone.features.8.branch2.0.0.weight", "i_backbone.features.8.branch2.0.1.weight", "i_backbone.features.8.branch2.0.1.bias", "i_backbone.features.8.branch2.0.1.running_mean", "i_backbone.features.8.branch2.0.1.running_var", "i_backbone.features.8.branch2.1.0.0.weight", "i_backbone.features.8.branch2.1.0.1.weight", "i_backbone.features.8.branch2.1.0.1.bias", "i_backbone.features.8.branch2.1.0.1.running_mean", "i_backbone.features.8.branch2.1.0.1.running_var", "i_backbone.features.8.branch2.1.1.0.weight", "i_backbone.features.8.branch2.1.1.1.weight", "i_backbone.features.8.branch2.1.1.1.bias", "i_backbone.features.8.branch2.1.1.1.running_mean", "i_backbone.features.8.branch2.1.1.1.running_var", "i_backbone.features.8.branch3.1.0.weight", "i_backbone.features.8.branch3.1.1.weight", "i_backbone.features.8.branch3.1.1.bias", "i_backbone.features.8.branch3.1.1.running_mean", "i_backbone.features.8.branch3.1.1.running_var", "i_backbone.features.9.branch0.0.weight", "i_backbone.features.9.branch0.1.weight", "i_backbone.features.9.branch0.1.bias", "i_backbone.features.9.branch0.1.running_mean", "i_backbone.features.9.branch0.1.running_var", "i_backbone.features.9.branch1.0.0.weight", "i_backbone.features.9.branch1.0.1.weight", "i_backbone.features.9.branch1.0.1.bias", "i_backbone.features.9.branch1.0.1.running_mean", "i_backbone.features.9.branch1.0.1.running_var", "i_backbone.features.9.branch1.1.0.0.weight", "i_backbone.features.9.branch1.1.0.1.weight", "i_backbone.features.9.branch1.1.0.1.bias", "i_backbone.features.9.branch1.1.0.1.running_mean", "i_backbone.features.9.branch1.1.0.1.running_var", "i_backbone.features.9.branch1.1.1.0.weight", "i_backbone.features.9.branch1.1.1.1.weight", "i_backbone.features.9.branch1.1.1.1.bias", "i_backbone.features.9.branch1.1.1.1.running_mean", "i_backbone.features.9.branch1.1.1.1.running_var", "i_backbone.features.9.branch2.0.0.weight", "i_backbone.features.9.branch2.0.1.weight", "i_backbone.features.9.branch2.0.1.bias", "i_backbone.features.9.branch2.0.1.running_mean", "i_backbone.features.9.branch2.0.1.running_var", "i_backbone.features.9.branch2.1.0.0.weight", "i_backbone.features.9.branch2.1.0.1.weight", "i_backbone.features.9.branch2.1.0.1.bias", "i_backbone.features.9.branch2.1.0.1.running_mean", "i_backbone.features.9.branch2.1.0.1.running_var", "i_backbone.features.9.branch2.1.1.0.weight", "i_backbone.features.9.branch2.1.1.1.weight", "i_backbone.features.9.branch2.1.1.1.bias", "i_backbone.features.9.branch2.1.1.1.running_mean", "i_backbone.features.9.branch2.1.1.1.running_var", "i_backbone.features.9.branch3.1.0.weight", "i_backbone.features.9.branch3.1.1.weight", "i_backbone.features.9.branch3.1.1.bias", "i_backbone.features.9.branch3.1.1.running_mean", "i_backbone.features.9.branch3.1.1.running_var", "i_backbone.features.10.branch0.0.weight", "i_backbone.features.10.branch0.1.weight", "i_backbone.features.10.branch0.1.bias", "i_backbone.features.10.branch0.1.running_mean", "i_backbone.features.10.branch0.1.running_var", "i_backbone.features.10.branch1.0.0.weight", "i_backbone.features.10.branch1.0.1.weight", "i_backbone.features.10.branch1.0.1.bias", "i_backbone.features.10.branch1.0.1.running_mean", "i_backbone.features.10.branch1.0.1.running_var", "i_backbone.features.10.branch1.1.0.0.weight", "i_backbone.features.10.branch1.1.0.1.weight", "i_backbone.features.10.branch1.1.0.1.bias", "i_backbone.features.10.branch1.1.0.1.running_mean", "i_backbone.features.10.branch1.1.0.1.running_var", "i_backbone.features.10.branch1.1.1.0.weight", "i_backbone.features.10.branch1.1.1.1.weight", "i_backbone.features.10.branch1.1.1.1.bias", "i_backbone.features.10.branch1.1.1.1.running_mean", "i_backbone.features.10.branch1.1.1.1.running_var", "i_backbone.features.10.branch2.0.0.weight", "i_backbone.features.10.branch2.0.1.weight", "i_backbone.features.10.branch2.0.1.bias", "i_backbone.features.10.branch2.0.1.running_mean", "i_backbone.features.10.branch2.0.1.running_var", "i_backbone.features.10.branch2.1.0.0.weight", "i_backbone.features.10.branch2.1.0.1.weight", "i_backbone.features.10.branch2.1.0.1.bias", "i_backbone.features.10.branch2.1.0.1.running_mean", "i_backbone.features.10.branch2.1.0.1.running_var", "i_backbone.features.10.branch2.1.1.0.weight", "i_backbone.features.10.branch2.1.1.1.weight", "i_backbone.features.10.branch2.1.1.1.bias", "i_backbone.features.10.branch2.1.1.1.running_mean", "i_backbone.features.10.branch2.1.1.1.running_var", "i_backbone.features.10.branch3.1.0.weight", "i_backbone.features.10.branch3.1.1.weight", "i_backbone.features.10.branch3.1.1.bias", "i_backbone.features.10.branch3.1.1.running_mean", "i_backbone.features.10.branch3.1.1.running_var", "i_backbone.features.11.branch0.0.weight", "i_backbone.features.11.branch0.1.weight", "i_backbone.features.11.branch0.1.bias", "i_backbone.features.11.branch0.1.running_mean", "i_backbone.features.11.branch0.1.running_var", "i_backbone.features.11.branch1.0.0.weight", "i_backbone.features.11.branch1.0.1.weight", "i_backbone.features.11.branch1.0.1.bias", "i_backbone.features.11.branch1.0.1.running_mean", "i_backbone.features.11.branch1.0.1.running_var", "i_backbone.features.11.branch1.1.0.0.weight", "i_backbone.features.11.branch1.1.0.1.weight", "i_backbone.features.11.branch1.1.0.1.bias", "i_backbone.features.11.branch1.1.0.1.running_mean", "i_backbone.features.11.branch1.1.0.1.running_var", "i_backbone.features.11.branch1.1.1.0.weight", "i_backbone.features.11.branch1.1.1.1.weight", "i_backbone.features.11.branch1.1.1.1.bias", "i_backbone.features.11.branch1.1.1.1.running_mean", "i_backbone.features.11.branch1.1.1.1.running_var", "i_backbone.features.11.branch2.0.0.weight", "i_backbone.features.11.branch2.0.1.weight", "i_backbone.features.11.branch2.0.1.bias", "i_backbone.features.11.branch2.0.1.running_mean", "i_backbone.features.11.branch2.0.1.running_var", "i_backbone.features.11.branch2.1.0.0.weight", "i_backbone.features.11.branch2.1.0.1.weight", "i_backbone.features.11.branch2.1.0.1.bias", "i_backbone.features.11.branch2.1.0.1.running_mean", "i_backbone.features.11.branch2.1.0.1.running_var", "i_backbone.features.11.branch2.1.1.0.weight", "i_backbone.features.11.branch2.1.1.1.weight", "i_backbone.features.11.branch2.1.1.1.bias", "i_backbone.features.11.branch2.1.1.1.running_mean", "i_backbone.features.11.branch2.1.1.1.running_var", "i_backbone.features.11.branch3.1.0.weight", "i_backbone.features.11.branch3.1.1.weight", "i_backbone.features.11.branch3.1.1.bias", "i_backbone.features.11.branch3.1.1.running_mean", "i_backbone.features.11.branch3.1.1.running_var", "i_backbone.features.12.branch0.0.weight", "i_backbone.features.12.branch0.1.weight", "i_backbone.features.12.branch0.1.bias", "i_backbone.features.12.branch0.1.running_mean", "i_backbone.features.12.branch0.1.running_var", "i_backbone.features.12.branch1.0.0.weight", "i_backbone.features.12.branch1.0.1.weight", "i_backbone.features.12.branch1.0.1.bias", "i_backbone.features.12.branch1.0.1.running_mean", "i_backbone.features.12.branch1.0.1.running_var", "i_backbone.features.12.branch1.1.0.0.weight", "i_backbone.features.12.branch1.1.0.1.weight", "i_backbone.features.12.branch1.1.0.1.bias", "i_backbone.features.12.branch1.1.0.1.running_mean", "i_backbone.features.12.branch1.1.0.1.running_var", "i_backbone.features.12.branch1.1.1.0.weight", "i_backbone.features.12.branch1.1.1.1.weight", "i_backbone.features.12.branch1.1.1.1.bias", "i_backbone.features.12.branch1.1.1.1.running_mean", "i_backbone.features.12.branch1.1.1.1.running_var", "i_backbone.features.12.branch2.0.0.weight", "i_backbone.features.12.branch2.0.1.weight", "i_backbone.features.12.branch2.0.1.bias", "i_backbone.features.12.branch2.0.1.running_mean", "i_backbone.features.12.branch2.0.1.running_var", "i_backbone.features.12.branch2.1.0.0.weight", "i_backbone.features.12.branch2.1.0.1.weight", "i_backbone.features.12.branch2.1.0.1.bias", "i_backbone.features.12.branch2.1.0.1.running_mean", "i_backbone.features.12.branch2.1.0.1.running_var", "i_backbone.features.12.branch2.1.1.0.weight", "i_backbone.features.12.branch2.1.1.1.weight", "i_backbone.features.12.branch2.1.1.1.bias", "i_backbone.features.12.branch2.1.1.1.running_mean", "i_backbone.features.12.branch2.1.1.1.running_var", "i_backbone.features.12.branch3.1.0.weight", "i_backbone.features.12.branch3.1.1.weight", "i_backbone.features.12.branch3.1.1.bias", "i_backbone.features.12.branch3.1.1.running_mean", "i_backbone.features.12.branch3.1.1.running_var", "i_backbone.features.14.branch0.0.weight", "i_backbone.features.14.branch0.1.weight", "i_backbone.features.14.branch0.1.bias", "i_backbone.features.14.branch0.1.running_mean", "i_backbone.features.14.branch0.1.running_var", "i_backbone.features.14.branch1.0.0.weight", "i_backbone.features.14.branch1.0.1.weight", "i_backbone.features.14.branch1.0.1.bias", "i_backbone.features.14.branch1.0.1.running_mean", "i_backbone.features.14.branch1.0.1.running_var", "i_backbone.features.14.branch1.1.0.0.weight", "i_backbone.features.14.branch1.1.0.1.weight", "i_backbone.features.14.branch1.1.0.1.bias", "i_backbone.features.14.branch1.1.0.1.running_mean", "i_backbone.features.14.branch1.1.0.1.running_var", "i_backbone.features.14.branch1.1.1.0.weight", "i_backbone.features.14.branch1.1.1.1.weight", "i_backbone.features.14.branch1.1.1.1.bias", "i_backbone.features.14.branch1.1.1.1.running_mean", "i_backbone.features.14.branch1.1.1.1.running_var", "i_backbone.features.14.branch2.0.0.weight", "i_backbone.features.14.branch2.0.1.weight", "i_backbone.features.14.branch2.0.1.bias", "i_backbone.features.14.branch2.0.1.running_mean", "i_backbone.features.14.branch2.0.1.running_var", "i_backbone.features.14.branch2.1.0.0.weight", "i_backbone.features.14.branch2.1.0.1.weight", "i_backbone.features.14.branch2.1.0.1.bias", "i_backbone.features.14.branch2.1.0.1.running_mean", "i_backbone.features.14.branch2.1.0.1.running_var", "i_backbone.features.14.branch2.1.1.0.weight", "i_backbone.features.14.branch2.1.1.1.weight", "i_backbone.features.14.branch2.1.1.1.bias", "i_backbone.features.14.branch2.1.1.1.running_mean", "i_backbone.features.14.branch2.1.1.1.running_var", "i_backbone.features.14.branch3.1.0.weight", "i_backbone.features.14.branch3.1.1.weight", "i_backbone.features.14.branch3.1.1.bias", "i_backbone.features.14.branch3.1.1.running_mean", "i_backbone.features.14.branch3.1.1.running_var", "i_backbone.features.15.branch0.0.weight", "i_backbone.features.15.branch0.1.weight", "i_backbone.features.15.branch0.1.bias", "i_backbone.features.15.branch0.1.running_mean", "i_backbone.features.15.branch0.1.running_var", "i_backbone.features.15.branch1.0.0.weight", "i_backbone.features.15.branch1.0.1.weight", "i_backbone.features.15.branch1.0.1.bias", "i_backbone.features.15.branch1.0.1.running_mean", "i_backbone.features.15.branch1.0.1.running_var", "i_backbone.features.15.branch1.1.0.0.weight", "i_backbone.features.15.branch1.1.0.1.weight", "i_backbone.features.15.branch1.1.0.1.bias", "i_backbone.features.15.branch1.1.0.1.running_mean", "i_backbone.features.15.branch1.1.0.1.running_var", "i_backbone.features.15.branch1.1.1.0.weight", "i_backbone.features.15.branch1.1.1.1.weight", "i_backbone.features.15.branch1.1.1.1.bias", "i_backbone.features.15.branch1.1.1.1.running_mean", "i_backbone.features.15.branch1.1.1.1.running_var", "i_backbone.features.15.branch2.0.0.weight", "i_backbone.features.15.branch2.0.1.weight", "i_backbone.features.15.branch2.0.1.bias", "i_backbone.features.15.branch2.0.1.running_mean", "i_backbone.features.15.branch2.0.1.running_var", "i_backbone.features.15.branch2.1.0.0.weight", "i_backbone.features.15.branch2.1.0.1.weight", "i_backbone.features.15.branch2.1.0.1.bias", "i_backbone.features.15.branch2.1.0.1.running_mean", "i_backbone.features.15.branch2.1.0.1.running_var", "i_backbone.features.15.branch2.1.1.0.weight", "i_backbone.features.15.branch2.1.1.1.weight", "i_backbone.features.15.branch2.1.1.1.bias", "i_backbone.features.15.branch2.1.1.1.running_mean", "i_backbone.features.15.branch2.1.1.1.running_var", "i_backbone.features.15.branch3.1.0.weight", "i_backbone.features.15.branch3.1.1.weight", "i_backbone.features.15.branch3.1.1.bias", "i_backbone.features.15.branch3.1.1.running_mean", "i_backbone.features.15.branch3.1.1.running_var". 
	Unexpected key(s) in state_dict: "k_backbone.0.weight", "k_backbone.0.bias", "k_backbone.2.weight", "k_backbone.2.bias", "k_backbone.4.weight", "k_backbone.4.bias". 
	size mismatch for classifier.1.weight: copying a param with shape torch.Size([6, 288]) from checkpoint, the shape in current model is torch.Size([6, 1024]).

In [5]:
trainer.fit(model = model,train_dataloaders= train_dl)


  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type             | Params
------------------------------------------------
0 | k_backbone | Sequential       | 163 K 
1 | i_backbone | S3D              | 8.7 M 
2 | classifier | Sequential       | 6.2 K 
3 | loss       | CrossEntropyLoss | 0     
------------------------------------------------
8.9 M     Trainable params
0         Non-trainable params
8.9 M     Total params
35.559    Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/358 [00:00<?, ?it/s] 

  i_inputs = self.transform(torch.tensor(batch_dict[input], dtype=torch.float32)/255.)


Epoch 0:  13%|█▎        | 45/358 [02:29<17:16,  3.31s/it, loss=1.54, v_num=5]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [61]:
backbone = s3d(weights=None)
batch1 = torch.randn(5, 3, 16, 224, 224)
batch2 = torch.randn(5, 3, 16, 224, 224)
batch3 = torch.randn(5, 36, 16)
expanded_batch3 = batch3.unsqueeze(3).unsqueeze(4).repeat(1,1,1,224,224)
expanded_batch3.shape

torch.Size([5, 36, 16, 224, 224])

In [42]:
expanded_batch3[0][1][1]

tensor([[-0.1213, -0.1213, -0.1213,  ..., -0.1213, -0.1213, -0.1213],
        [-0.1213, -0.1213, -0.1213,  ..., -0.1213, -0.1213, -0.1213],
        [-0.1213, -0.1213, -0.1213,  ..., -0.1213, -0.1213, -0.1213],
        ...,
        [-0.1213, -0.1213, -0.1213,  ..., -0.1213, -0.1213, -0.1213],
        [-0.1213, -0.1213, -0.1213,  ..., -0.1213, -0.1213, -0.1213],
        [-0.1213, -0.1213, -0.1213,  ..., -0.1213, -0.1213, -0.1213]])

In [70]:
# fused = torch.cat([batch1, batch2], dim=1)
fused = torch.cat([batch1, batch2, expanded_batch3], dim=1)
fused.shape

torch.Size([5, 42, 16, 224, 224])

In [62]:
backbone.features[0]

TemporalSeparableConv(
  (0): Conv3dNormActivation(
    (0): Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=0.001, momentum=0.001, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (1): Conv3dNormActivation(
    (0): Conv3d(64, 64, kernel_size=(7, 1, 1), stride=(2, 1, 1), padding=(3, 0, 0), bias=False)
    (1): BatchNorm3d(64, eps=0.001, momentum=0.001, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
)

In [63]:
backbone.features[0][0][0] = nn.Conv3d(42, 128, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
backbone.features[0][0][1] = nn.BatchNorm3d(128, eps=0.001, momentum=0.001, affine=True, track_running_stats=True)
backbone.features[0][1][0] = nn.Conv3d(128, 64, kernel_size=(7, 1, 1), stride=(2, 1, 1), padding=(3, 0, 0), bias=False)

backbone.classifier = nn.Identity()

In [64]:
backbone(fused).shape

torch.Size([5, 1024])