In [1]:
# Prefer an importable key_utils; otherwise fall back to the Kaggle dataset
# mount that ships the same module.
try:
    from key_utils import KeySegmentDataModule, KeyClf, id2label, label2id
except ImportError:
    # Only a failed import should trigger the fallback. A bare `except:` would
    # also swallow KeyboardInterrupt/SystemExit and hide unrelated errors
    # raised while key_utils is being imported.
    import sys
    sys.path.append("/kaggle/input/keystroke-util")
    from key_utils import KeySegmentDataModule, KeyClf, id2label, label2id

from transformers import VivitImageProcessor, VivitForVideoClassification
from lightning.pytorch.callbacks import EarlyStopping
import torchvision
import torchvision.transforms.functional
import lightning as L
import torch

# Frame preprocessor loaded from the same checkpoint name that ViVitKeyClf
# passes to VivitForVideoClassification.from_pretrained.
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")

def preprocess(frames):
    """Run the module-level ViViT image processor over one clip.

    Args:
        frames: Iterable of frames; it is materialised into a list before
            being handed to ``image_processor``.

    Returns:
        The first entry of the processor's ``pixel_values`` output, requested
        as PyTorch tensors (``return_tensors="pt"``).
    """
    frame_list = list(frames)
    processed = image_processor(frame_list, return_tensors="pt")
    return processed['pixel_values'][0]


def transforms(frames):
    """Crop a clip to the bottom half of each frame, resize, and preprocess.

    Args:
        frames: Clip whose last two dimensions are read as height and width
            (``frames.shape[-2]`` and ``frames.shape[-1]``).

    Returns:
        The pixel values produced by ``preprocess`` for the cropped clip.
    """
    frame_h, frame_w = frames.shape[-2], frames.shape[-1]
    half_h = frame_h // 2

    # Keep rows [half_h, 2 * half_h) across the full width, then resize the
    # crop to 224x224.
    bottom_half = torchvision.transforms.functional.resized_crop(
        frames,
        top=half_h,
        left=0,
        height=half_h,
        width=frame_w,
        size=(224, 224),
    )

    # The processor call used to be duplicated here line-for-line; delegate to
    # the shared helper so both code paths stay in sync.
    return preprocess(bottom_half)
    
# Data module over the segment directory. `transforms` is handed to it as the
# transform callable (presumably applied per clip -- confirm in key_utils).
# num_workers=0 keeps loading in the main process; Lightning warns about this
# in the cell output.
dm = KeySegmentDataModule(segment_dir='datasets/angle/segments_dir', 
                          num_workers=0,
                          transforms=transforms)
# Read once at module level; ViVitKeyClf.__init__ forwards this value as the
# first argument of KeyClf.__init__ (presumably per-class loss weights -- TODO confirm).
weights = dm.train_weights


class ViVitKeyClf(KeyClf):
    """KeyClf subclass whose model is a pretrained ViViT video classifier."""

    def __init__(self, learning_rate=0.01):
        """Initialise the base class and load the ViViT checkpoint.

        Args:
            learning_rate: Forwarded unchanged to ``KeyClf.__init__``.
        """
        # `weights` is the module-level value read from the data module.
        super().__init__(weights, learning_rate)

        checkpoint = "google/vivit-b-16x2-kinetics400"
        # The label maps and num_frames=8 differ from the checkpoint's own
        # configuration, so ignore_mismatched_sizes=True lets the mismatched
        # tensors be newly initialised instead of raising on load.
        self.model = VivitForVideoClassification.from_pretrained(
            checkpoint,
            num_frames=8,
            ignore_mismatched_sizes=True,
            label2id=label2id,
            id2label=id2label,
        )

    def forward(self, batch):
        """Compute the loss and predicted class ids for one batch.

        Args:
            batch: A ``(videos, targets)`` pair.

        Returns:
            Tuple ``(loss, pred_ids)``: ``loss`` is ``self.loss_fn`` applied to
            the logits and the targets cast to long; ``pred_ids`` is the
            argmax of the logits along dimension 1.
        """
        videos, targets = batch
        logits = self.model(videos).logits
        loss = self.loss_fn(logits, targets.long())
        return loss, logits.argmax(dim=1)



# Build the classifier with its default learning rate and fit it on the data
# module defined above.
module = ViVitKeyClf()
trainer = L.Trainer(
    # deterministic=True,
    # devices=[0, 1],
    # accelerator="gpu",
    fast_dev_run=False,
    max_epochs=100,
    # Stop early once 'val_loss' has gone 5 checks without improving.
    callbacks=EarlyStopping(monitor='val_loss', patience=5),
)

trainer.fit(module, dm)
# trainer.test(module, dm)

  from .autonotebook import tqdm as notebook_tqdm


Train:
 Counter({'idle': 2715, 'space': 917, 'e': 497, 'BackSpace': 429, 'i': 328, 'a': 320, 'o': 302, 't': 289, 'r': 250, 'n': 246, 's': 215, 'u': 184, 'l': 183, 'h': 162, 'd': 159, 'c': 155, 'y': 119, 'g': 109, 'm': 109, 'p': 108, 'w': 103, 'b': 91, 'k': 86, 'f': 85, 'dot': 84, 'v': 73, 'comma': 66, 'j': 62, 'z': 58, 'x': 54, 'q': 52})
Val:
 Counter({'idle': 973, 'space': 311, 'e': 162, 'BackSpace': 136, 'i': 112, 'a': 108, 't': 96, 'o': 87, 'n': 75, 'r': 73, 'h': 67, 's': 62, 'u': 57, 'l': 52, 'c': 49, 'd': 49, 'f': 43, 'y': 41, 'm': 39, 'g': 38, 'w': 28, 'p': 26, 'comma': 26, 'b': 26, 'z': 24, 'dot': 23, 'v': 22, 'x': 18, 'k': 16, 'j': 16, 'q': 15})
Test:
 Counter({'idle': 944, 'space': 316, 'BackSpace': 164, 'e': 147, 'i': 112, 't': 103, 'o': 94, 'n': 85, 'r': 84, 'a': 80, 's': 64, 'l': 62, 'c': 57, 'u': 54, 'm': 53, 'd': 46, 'h': 42, 'w': 37, 'f': 36, 'y': 36, 'g': 33, 'p': 31, 'b': 31, 'dot': 28, 'k': 23, 'v': 23, 'q': 21, 'comma': 20, 'x': 20, 'z': 13, 'j': 11})
train_weights: 

Some weights of VivitForVideoClassification were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized because the shapes did not match:
- vivit.embeddings.position_embeddings: found shape torch.Size([1, 3137, 768]) in the checkpoint and torch.Size([1, 785, 768]) in the model instantiated
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([31, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([31]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type                        | Params | Mode 
------------------------------------------------------------------
0 | loss_fn   | CrossEntropyLoss            | 0      | train
1 | t

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
  return torch.tensor(value)


                                                                           

/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 1/2153 [00:10<6:00:53,  0.10it/s, v_num=12]