In [1]:
!pip install lightning torchvision

Collecting lightning
  Downloading lightning-2.3.2-py3-none-any.whl.metadata (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.1/54.1 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Downloading lightning-2.3.2-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning
Successfully installed lightning-2.3.2


In [2]:
import torch
import torch.nn as nn
import lightning as L
import pandas as pd
import torchvision
import numpy as np
import torch.nn.functional as F 
from torch.autograd import Variable
import math
from functools import partial
import pathlib
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_video
import lightning as L
from lightning.pytorch.loggers import CSVLogger
import torchmetrics
from lightning.pytorch.callbacks import EarlyStopping
from sklearn.metrics import classification_report, confusion_matrix, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torchvision.transforms import CenterCrop, v2
from datetime import datetime
import csv
import glob


In [3]:
# Class vocabulary: a label's position in this list is its class id.
# '[i]' is the label given to idle clips and '[s]' is the fallback for keys
# that are not listed here; the remaining entries are literal key names.
id2Label = ['[i]', 'BackSpace', ',', '[s]', '.'] + [
    chr(code) for code in range(ord('a'), ord('z') + 1)
]
# Reverse lookup: label string -> class id.
label2Id = dict(zip(id2Label, range(len(id2Label))))

# Dataset locations: one CSV of key presses and one folder of frames per video.
labels_dir = '/kaggle/input/keycls-test/test_data/labels'
videos_dir = '/kaggle/input/keycls-test/test_data/raw_frames'

# Worker processes used by every DataLoader.
NUM_WORKERS = 4

# Resnet

In [4]:
#### RESNET 3D #### 
def conv3x3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3x3 `nn.Conv3d` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)


def downsample_basic_block(x, planes, stride):
    """Parameter-free shortcut: subsample `x` and zero-pad its channels up to `planes`.

    Args:
        x: 5-D tensor; dimension 1 is treated as the channel dimension.
        planes: number of channels the result must have (>= x.size(1)).
        stride: stride of the 1x1x1 average pooling used to subsample.

    Returns:
        A tensor on the same device and with the same dtype as `x`, whose first
        `x.size(1)` channels are the subsampled input and whose remaining
        channels are zeros.
    """
    # A kernel of size 1 makes the pooling a pure strided subsampling.
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)

    # new_zeros inherits device and dtype from `out`, so the padding lands next
    # to the activations on GPU as well. (The old storage-type check never
    # matched a tensor and left the padding on the CPU.)
    zero_pads = out.new_zeros(
        out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4))

    # Concatenate the live tensor rather than `.data`, so gradients still flow
    # through the shortcut branch.
    return torch.cat([out, zero_pads], dim=1)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3x3 convolutions, each followed by batch norm.

    `expansion` is the ratio between the block's output channels and `planes`
    (1 for this block).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        inplanes: channels of the incoming tensor
        planes: channels produced by both convolutions
        stride: stride of the first convolution
        downsample: optional callable applied to the input to build the shortcut
        """
        super().__init__()
        # First conv carries the stride; the second always keeps the resolution.
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: conv-bn-relu, then conv-bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch: identity unless a downsample callable was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual block: 1x1x1 conv, 3x3x3 conv, then 1x1x1 conv widening to 4 * planes.

    `expansion` is the ratio between the block's output channels and `planes`.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        inplanes: channels of the incoming tensor
        planes: channels of the two inner convolutions (output has planes * 4)
        stride: stride of the middle 3x3x3 convolution
        downsample: optional callable applied to the input to build the shortcut
        """
        super().__init__()
        widened = planes * 4

        # 1x1x1 convolution to `planes` channels.
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)

        # 3x3x3 convolution; the only layer in the block that uses `stride`.
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)

        # 1x1x1 convolution to the widened output.
        self.conv3 = nn.Conv3d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Shortcut branch: identity unless a downsample callable was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """3D ResNet: a stem, four residual stages, average pooling and a linear classifier.

    The temporal axis is halved by the max-pool and by stages 2-4 (16x overall);
    the two spatial axes are additionally halved by the stem convolution (32x
    overall). The final average-pool kernel is derived from `sample_duration`
    and `sample_size` on that basis, so inputs are expected to have that size.
    """

    def __init__(self, block, layers, sample_size, sample_duration, shortcut_type='B', num_classes=400):
        """
        block: residual block class (BasicBlock or Bottleneck); must expose `expansion`
        layers: number of blocks in each of the four stages, e.g. [3, 4, 23, 3]
        sample_size: spatial size (height == width) of the input clips
        sample_duration: number of frames in the input clips
        shortcut_type: 'A' for zero-padded parameter-free shortcuts,
            anything else for a 1x1x1 conv + batch-norm projection
        num_classes: size of the classifier output
        """
        super(ResNet, self).__init__()
        # Channel count entering the next stage; _make_layer advances it.
        self.inplanes = 64

        # Stem: stride (1, 2, 2) keeps every frame and halves height and width.
        self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], shortcut_type, stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], shortcut_type, stride=2)

        # Pool over whatever extent is left after the 16x / 32x reductions above.
        last_duration = int(math.ceil(sample_duration / 16))
        last_size = int(math.ceil(sample_size / 32))
        self.avgpool = nn.AvgPool3d(
            (last_duration, last_size, last_size), stride=1)

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # In-place initializer; the non-underscore variant is deprecated
                # and emits a warning on every model construction.
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one stage of `blocks` residual blocks producing planes * block.expansion channels."""
        out_planes = planes * block.expansion

        # A shortcut transform is only needed when the first block changes the
        # resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=out_planes,
                    stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm3d(out_planes))

        # Only the first block strides / projects; the rest keep the shape.
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a 5-D input with 3 channels at dim 1."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        # Collapse everything but the batch dimension for the linear classifier.
        x = torch.flatten(x, 1)
        return self.fc(x)

def resnet10(**kwargs):
    """ResNet built from BasicBlock with stage depths [1, 1, 1, 1]; kwargs go to ResNet."""
    return ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)


def resnet18(**kwargs):
    """ResNet built from BasicBlock with stage depths [2, 2, 2, 2]; kwargs go to ResNet."""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    """ResNet built from BasicBlock with stage depths [3, 4, 6, 3]; kwargs go to ResNet."""
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    """ResNet built from Bottleneck with stage depths [3, 4, 6, 3]; kwargs go to ResNet."""
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    """ResNet built from Bottleneck with stage depths [3, 4, 23, 3]; kwargs go to ResNet."""
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152(**kwargs):
    """ResNet built from Bottleneck with stage depths [3, 8, 36, 3]; kwargs go to ResNet."""
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)


def resnet200(**kwargs):
    """ResNet built from Bottleneck with stage depths [3, 24, 36, 3]; kwargs go to ResNet."""
    return ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)

# Lightning dataset

In [5]:
class KeyDataset(torch.utils.data.Dataset):
    """Fixed-length frame clips from one video, each labelled with a key or as idle.

    The label CSV `{labels_dir}/{video_name}.csv` must have a `Frame` column (the
    frame at which a key was pressed) and a `Key` column. Rows are assumed to be
    sorted by frame. For every row the dataset emits one keystroke clip around
    the key frame, and, when the gap since the previous keystroke clip is long
    enough, as many non-overlapping idle clips as fit into that gap.

    Every clip spans exactly CLIP_LEN frames. When the key is pressed within the
    first FRAMES_BEFORE frames, the window is shifted forward to start at frame 0
    instead of being truncated, so that samples can always be stacked into a batch.

    NOTE(review): clip ends are not checked against the number of frames on disk;
    a key press within the last few frames would reference missing files - confirm
    against the data.
    """

    CLIP_LEN = 8        # frames per clip; KeyClf builds its network with sample_duration=8
    FRAMES_BEFORE = 3   # frames of context kept before the key frame
    IDLE_MARGIN = 2     # offset between a keystroke clip and the neighbouring idle range
    IDLE_LABEL = '[i]'
    UNKNOWN_KEY_LABEL = '[s]'  # label substituted for keys that are not in id2Label

    def __init__(self, video_name):
        """Read the label CSV of `video_name` and build the list of (frame range, label) segments."""
        df = pd.read_csv(f'{labels_dir}/{video_name}.csv')

        segments = []
        prev_pos_end = None  # last frame of the previous keystroke clip, None before the first row
        for _, row in df.iterrows():
            key_frame = int(row['Frame'])  # Frame number where key was pressed
            key_value = row['Key']  # Key pressed
            if key_value not in id2Label:
                key_value = self.UNKNOWN_KEY_LABEL

            # Keystroke window: always CLIP_LEN frames, clamped at the video start.
            pos_start = max(key_frame - self.FRAMES_BEFORE, 0)
            pos_end = pos_start + self.CLIP_LEN - 1

            # Idle range in front of this keystroke.
            if prev_pos_end is None:
                neg_start = 0
                is_idle_before = True
            else:
                neg_start = prev_pos_end + self.IDLE_MARGIN
                is_idle_before = (pos_start - prev_pos_end) - 1 >= self.CLIP_LEN
            neg_end = pos_start - self.IDLE_MARGIN

            # Negative class video segments before: back-to-back clips that fit
            # entirely inside [neg_start, neg_end].
            if is_idle_before:
                last_idle_start = neg_end - self.CLIP_LEN + 1
                for idle_start in range(neg_start, last_idle_start + 1, self.CLIP_LEN):
                    segments.append(
                        ([idle_start, idle_start + self.CLIP_LEN - 1], self.IDLE_LABEL))

            # Current video with keystroke
            segments.append(([pos_start, pos_end], key_value))
            prev_pos_end = pos_end

        self.video_name = video_name
        self.segments = segments

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        """Return (stacked frames of the clip, class id) for segment `idx`."""
        (start, end), label = self.segments[idx]

        # Segment bounds are inclusive.
        frames = [
            torchvision.io.read_image(f"{videos_dir}/{self.video_name}/frame_{i}.jpg")
            for i in range(start, end + 1)
        ]

        return torch.stack(frames), label2Id[label]

    def get_class_counts(self):
        """Return a float array of length len(id2Label) with the number of clips per class id."""
        weights = np.zeros(len(id2Label))
        for _, label in self.segments:
            weights[label2Id[label]] += 1
        return weights

class KeyDataModule(L.LightningDataModule):
    """Train / validation / test loaders built from lists of video names.

    Each video becomes one KeyDataset; the datasets of a split are concatenated.
    `train_weights` holds per-class loss weights derived from the training split.
    """

    def __init__(self, batch_size, train_vids, val_vids, test_vids):
        """
        batch_size: batch size used by all three loaders
        train_vids / val_vids / test_vids: video names for each split
        """
        super().__init__()
        self.batch_size = batch_size
        self.train_datasets = [KeyDataset(video_name) for video_name in train_vids]
        self.val_datasets = [KeyDataset(video_name) for video_name in val_vids]
        self.test_datasets = [KeyDataset(video_name) for video_name in test_vids]

        self.train_dataset = torch.utils.data.ConcatDataset(self.train_datasets)
        self.test_dataset = torch.utils.data.ConcatDataset(self.test_datasets)
        self.val_dataset = torch.utils.data.ConcatDataset(self.val_datasets)

        print(f"Train: {len(self.train_dataset)}; Val: {len(self.val_dataset)}; Test: {len(self.test_dataset)}")

        train_counts = np.array(
            [d.get_class_counts() for d in self.train_datasets]).sum(axis=0)
        print(f"Train counts: {train_counts}")
        train_total_samples = np.array([len(d) for d in self.train_datasets]).sum(axis=0)

        # Balanced class weights: total / (n_classes * count), so rare classes
        # weigh more in the loss. (The reciprocal, used previously, up-weights
        # the most frequent class instead.) Classes absent from the training
        # split get weight 0 rather than a division by zero.
        self.train_weights = np.divide(
            train_total_samples,
            train_counts * len(id2Label),
            out=np.zeros_like(train_counts, dtype=float),
            where=train_counts > 0)

    def train_dataloader(self):
        # Shuffle: the concatenated dataset is ordered by video and by time, so
        # unshuffled batches would consist of consecutive clips of one video.
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          num_workers=NUM_WORKERS,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          num_workers=NUM_WORKERS,
                          shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          num_workers=NUM_WORKERS,
                          shuffle=False)

class KeyClf(L.LightningModule):
    """Keystroke classifier: a 3D ResNet-101 trained with class-weighted cross-entropy."""

    def __init__(self, img_size, num_classes, learning_rate, weights):
        """
        img_size: side of the center crop fed to the network
        num_classes: number of output classes
        learning_rate: Adam learning rate
        weights: per-class loss weights (array-like of length num_classes)
        """
        super().__init__()
        self.model = resnet101(sample_size=img_size,
                               sample_duration=8,
                               shortcut_type='B',
                               num_classes=num_classes)

        # as_tensor accepts lists, numpy arrays and tensors alike; clone() keeps
        # the loss from sharing memory with the caller's array.
        class_weights = torch.as_tensor(weights, dtype=torch.float32).detach().clone()
        self.loss_fn = torch.nn.CrossEntropyLoss(class_weights)
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.lr = learning_rate
        self.transforms = v2.Compose([
            v2.CenterCrop(img_size),
            v2.ToDtype(torch.float32, scale=True),
        ])

        # Label strings collected over a test run for the final report.
        self.test_preds = []
        self.test_targets = []
        self.save_hyperparameters()

    def _forward_batch(self, batch):
        """Preprocess a (videos, targets) batch and return (logits, targets, loss)."""
        videos, targets = batch
        videos = self.transforms(videos)
        # Swap dims 1 and 2 so that the channel dimension comes second, as Conv3d
        # expects (assumes batches arrive as (batch, frames, channels, H, W)).
        videos = videos.permute(0, 2, 1, 3, 4)
        preds = self.model(videos)
        loss = self.loss_fn(preds, targets.long())
        return preds, targets, loss

    def on_test_start(self):
        # Start each test run with empty buffers so repeated runs do not accumulate.
        self.test_preds = []
        self.test_targets = []

    def test_step(self, batch):
        preds, targets, loss = self._forward_batch(batch)

        # tolist() always yields a Python list, including for a batch of one.
        self.test_preds += [id2Label[_id] for _id in torch.argmax(preds, dim=1).tolist()]
        self.test_targets += [id2Label[_id] for _id in targets.tolist()]

        self.log_dict({'test_acc': self.accuracy(preds, targets), 'test_loss': loss})

    def on_test_end(self):
        print(classification_report(self.test_targets, self.test_preds))

    def training_step(self, batch):
        preds, targets, loss = self._forward_batch(batch)
        self.log_dict({"train_loss": loss, "train_acc": self.accuracy(preds, targets)})
        return loss

    def validation_step(self, batch):
        preds, targets, loss = self._forward_batch(batch)
        self.log_dict({"val_loss": loss, "val_acc": self.accuracy(preds, targets)})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

## Train

In [6]:
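# NOTE(review): val_vids (video_32-34) and test_vids (video_35-36) below are also listed in
# train_vids, so validation/test metrics from this configuration are measured on training data.
# dm = KeyDataModule(batch_size=8, 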
#                    train_vids=['video_1', 'video_2', 'video_3', 'video_4', 'video_5', 'video_6', 'video_7', 'video_8', 'video_9', 'video_10',
#                               'video_11', 'video_12', 'video_13', 'video_14', 'video_15', 'video_16', 'video_17', 'video_18', 'video_19',
#                               'video_21', 'video_22', 'video_23', 'video_24', 'video_25', 'video_26', 'video_27', 'video_28', 'video_29', 'video_30',
#                               'video_31', 'video_32', 'video_33', 'video_34', 'video_35', 'video_36'], 
#                    val_vids=['video_32', 'video_33', 'video_34'], 
#                    test_vids=['video_35', 'video_36'])
# model = KeyClf(img_size=320, num_classes=len(id2Label), learning_rate=0.001, weights=dm.train_weights)
# trainer = L.Trainer(
#     # deterministic=True,
#     devices=[0, 1],
#     max_time="00:11:00:00",
#     callbacks=[EarlyStopping(monitor="val_loss", patience=10)],
#     fast_dev_run=False,
#     accelerator="gpu")
# trainer.fit(model, dm)
# trainer.test(model, dm)

## Test

In [7]:
CLIP_LEN = 8     # frames per clip; the network was built with sample_duration=8
CLIP_STRIDE = 1  # 1 = one prediction per start frame; CLIP_LEN = non-overlapping clips

trained_model = KeyClf.load_from_checkpoint("/kaggle/input/keyclf/pytorch/v1/1/epoch7-step34979.ckpt")
trained_model.freeze()  # eval mode, parameters no longer require gradients

video = 'video_33'
frame_dir = f'{videos_dir}/{video}'
# Frames are read by index, so the folder is assumed to hold frame_0.jpg .. frame_{n-1}.jpg.
num_frames = len(glob.glob(f'{frame_dir}/*.jpg'))
frames = torch.stack([
    torchvision.io.read_image(f"{frame_dir}/frame_{i}.jpg")
    for i in range(num_frames)
])

preds = []
with torch.no_grad():
    # Slide over the whole video. The previous bound, len(frames) // 8, combined
    # with stride-1 slices only covered the first eighth of the frames.
    for start in range(0, len(frames) - CLIP_LEN + 1, CLIP_STRIDE):
        clip = trained_model.transforms(frames[start: start + CLIP_LEN])
        # Swap the first two dims, add a batch dim, and move to the model's device.
        clip = clip.permute(1, 0, 2, 3).unsqueeze(0).to(trained_model.device)
        out = F.softmax(trained_model.model(clip), 1)[0]
        _id = torch.argmax(out).item()
        label = id2Label[_id]

        print(f"Frame {start}, {label}, {out[_id]}")
        preds.append((label, out[_id]))

  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')


Frame 0, [i], 0.992448091506958
Frame 1, [i], 0.9974133372306824
Frame 2, [i], 0.9986820816993713
Frame 3, [i], 0.9989448189735413
Frame 4, [i], 0.9992009997367859
Frame 5, [i], 0.9991810917854309
Frame 6, [i], 0.9992801547050476
Frame 7, [i], 0.9989452958106995
Frame 8, [i], 0.9984174966812134
Frame 9, [i], 0.9987190961837769
Frame 10, [i], 0.9988056421279907
Frame 11, [i], 0.9984347224235535
Frame 12, [i], 0.9980826377868652
Frame 13, [i], 0.9971950054168701
Frame 14, [i], 0.9960976839065552
Frame 15, [i], 0.9978142976760864
Frame 16, [i], 0.9990190267562866
Frame 17, [i], 0.9991820454597473
Frame 18, [i], 0.9993066787719727
Frame 19, [i], 0.9987260699272156
Frame 20, [i], 0.9972108006477356
Frame 21, [i], 0.9962981343269348
Frame 22, [i], 0.9960047602653503
Frame 23, [i], 0.9984738230705261
Frame 24, [i], 0.9983382225036621
Frame 25, [i], 0.9941334128379822
Frame 26, [i], 0.9795316457748413
Frame 27, [i], 0.9123054146766663
Frame 28, [i], 0.845310628414154
Frame 29, [i], 0.802419722