In [1]:
!pip install torch lightning torchvision pyav

Collecting lightning
  Downloading lightning-2.3.1-py3-none-any.whl.metadata (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.1/54.1 kB[0m [31m847.4 kB/s[0m eta [36m0:00:00[0m
Collecting pyav
  Downloading pyav-12.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Downloading lightning-2.3.1-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyav-12.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.2/30.2 MB[0m [31m30.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyav, lightning
Successfully installed lightning-2.3.1 pyav-12.1.0


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial
import pathlib
import torch
# from GesRec.models.resnet import resnet101
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_video
import lightning as L
from lightning.pytorch.loggers import CSVLogger
import torchmetrics
from lightning.pytorch.callbacks import EarlyStopping

# Training hyperparameters
IMG_SIZE = 112
FRAMES_PER_VIDEO = 8
NUM_CLASSES = 30
LEARNING_RATE = 0.001
BATCH_SIZE = 16
MAX_EPOCHS = 1000
MAX_TIME = "00:10:00:00"

# Dataset
DATA_DIR = "/kaggle/input/key-clf/key_clf_data_112_112"
NUM_WORKERS = 4

FAST_DEV_RUN = False
CHECKPOINT_DIR = "resnet/"

# Compute related
ACCELERATOR = "gpu"
DEVICES = [0,1]

def conv3x3x3(in_planes, out_planes, stride=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


def downsample_basic_block(x, planes, stride):
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    zero_pads = torch.Tensor(
        out.size(0), 
        planes - out.size(1), 
        out.size(2), 
        out.size(3),
        out.size(4)
    ).zero_()
    
    if isinstance(out.data, torch.cuda.FloatTensor):
        zero_pads = zero_pads.cuda()

    out = Variable(torch.cat([out.data, zero_pads], dim=1))

    return out


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)

        self.conv2 = nn.Conv3d(
            planes, planes, 
            kernel_size=3, stride=stride, padding=1, 
            bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 4)
        
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 num_classes=400):
        
        """
        block: basic block or bottle neck
        layers: define Resnet architecture 34, 101, 152 etc
        sample size: image size
        shortcut_type: 'A' or 'B'
        num_classes: ...
        """
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv3d(
            3,
            64,
            kernel_size=7,
            stride=(1, 2, 2),
            padding=(3, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)


        self.layer1 = self._make_layer(
            block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=2)
        
        
        last_duration = int(math.ceil(sample_duration / 16))
        last_size = int(math.ceil(sample_size / 32))
        self.avgpool = nn.AvgPool3d(
            (last_duration, last_size, last_size), stride=1)
        
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False), nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def get_fine_tuning_parameters(model, ft_portion):
    if ft_portion == "complete":
        return model.parameters()

    elif ft_portion == "last_layer":
        ft_module_names = []
        ft_module_names.append('classifier')

        parameters = []
        for k, v in model.named_parameters():
            for ft_module in ft_module_names:
                if ft_module in k:
                    parameters.append({'params': v})
                    break
            else:
                parameters.append({'params': v, 'lr': 0.0})
        return parameters
    else:
        raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected")


def resnet10(**kwargs):
    """Constructs a ResNet-18 model.
    """
    model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
    return model


def resnet18(**kwargs):
    """Constructs a ResNet-18 model.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def resnet34(**kwargs):
    """Constructs a ResNet-34 model.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def resnet50(**kwargs):
    """Constructs a ResNet-50 model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model


def resnet101(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model


def resnet152(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model


def resnet200(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
    return model



class KeyClf(L.LightningModule):
    def __init__(self, img_size, frames_per_video, num_classes, learning_rate, weights):
        super().__init__()
        self.model = resnet101(sample_size=img_size,
                 sample_duration=frames_per_video,
                 shortcut_type='A',
                 num_classes=num_classes
                )
        
        self.loss_fn = torch.nn.CrossEntropyLoss(torch.tensor(weights))
        self.accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )

        # self.training_step_preds = []
        # self.training_step_targets = []
        # self.validate_step_preds = []
        # self.validate_step_targets = []

        self.lr = learning_rate


    def training_step(self, batch):
        videos, targets = batch
        preds = self.model(videos.to('cuda'))
        loss = self.loss_fn(preds, targets.long())
        # self.training_step_preds.append(preds)
        # self.training_step_targets.append(targets)

        self.log_dict(
            {
                "train_loss": loss,
                "train_acc": self.accuracy(preds, targets),
            },
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )

        return loss
    
    # def on_training_epoch_end(self):
    #     preds = torch.cat(self.training_step_preds)
    #     targets = torch.cat(self.training_step_targets)
  
    #     self.log_dict(
    #         {
    #             "train_acc": self.accuracy(preds, targets),
    #         },
    #         on_step=False,
    #         on_epoch=True,
    #         prog_bar=True,
    #         logger=True,
    #     )

    #     self.training_step_preds.clear()
    #     self.training_step_targets.clear()


    def validation_step(self, batch):
        videos, targets = batch
        preds = self.model(videos)
        loss = self.loss_fn(preds, targets.long())
        # self.validate_step_preds.append(preds)
        # self.validate_step_targets.append(targets)

        self.log_dict(
            {
                "val_loss": loss,
                "val_acc": self.accuracy(preds, targets),
            },
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

        return loss
    
    
    # def on_validation_epoch_end(self):
    #     preds = torch.cat(self.validate_step_preds)
    #     targets = torch.cat(self.validate_step_targets)
       
    #     self.log_dict(
    #         {
    #             "val_acc": self.accuracy(preds, targets),
    #         },
    #         on_step=False,
    #         on_epoch=True,
    #         prog_bar=True,
    #         logger=True
    #     )

    #     self.validate_step_preds.clear()
    #     self.validate_step_targets.clear()
    
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
    
class KeyStrokeClsDataset(Dataset):
    def __init__(self, data_dir, mode):
        self.dataset_root_path = pathlib.Path(data_dir)
        self.all_video_file_paths =  list(self.dataset_root_path.glob(f"{mode}/*/*.mp4"))

        self.class_labels = sorted({str(path).split("/")[-2] for path in self.all_video_file_paths})
        
        self.label2id = {label: i for i, label in enumerate(self.class_labels)}
        self.id2label  = {i: label for label, i in self.label2id.items()}
  
    def __len__(self):
        return len(self.all_video_file_paths)

    def __getitem__(self, idx):
        file_path = self.all_video_file_paths[idx]
        vframes, _, _ = read_video(file_path, pts_unit='sec')
        label = str(file_path).split("/")[-2]

        # permute to (num_frames, num_channels, height, width)
        vframes = vframes.permute(3, 0, 1, 2).float() / 255.0
    
        return vframes, self.label2id[label]


class KeyClsData(L.LightningDataModule):
    def __init__(self, batch_size, data_dir):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_root_path = pathlib.Path(data_dir)
        self.data_dir = data_dir

        train_video_file_paths =  list(self.dataset_root_path.glob(f"train/*/*.mp4"))

        class_labels = sorted({str(path).split("/")[-2] for path in train_video_file_paths})
        print('class_labels: ', class_labels)
        
        total = len(train_video_file_paths)
        weights = []
        for label in class_labels:
            samples = len(list(self.dataset_root_path.glob(f"train/{label}/*.mp4")))
            weights.append(total / (NUM_CLASSES * samples))
        self.weights = weights
        print('class_weights: ', weights)
   
        
    def train_dataloader(self):
        train_dataset = KeyStrokeClsDataset(self.data_dir, 'train')
        print("Train dataset:", len(train_dataset))
        return DataLoader(train_dataset, batch_size=self.batch_size, num_workers=NUM_WORKERS, persistent_workers=True)
    
    def val_dataloader(self):
        val_dataset = KeyStrokeClsDataset(self.data_dir, 'val')
        print("Val dataset:", len(val_dataset))
        return DataLoader(val_dataset, batch_size=self.batch_size, num_workers=NUM_WORKERS, persistent_workers=True)
    
    def test_dataloader(self):
        test_dataset = KeyStrokeClsDataset(self.data_dir, 'test')
        print("Test dataset:", len(test_dataset))
        return DataLoader(test_dataset, batch_size=self.batch_size, num_workers=NUM_WORKERS, persistent_workers=True)
    

if __name__ == '__main__':
    data = KeyClsData(batch_size=BATCH_SIZE, data_dir=DATA_DIR)
    model = KeyClf(img_size=IMG_SIZE, 
                frames_per_video=FRAMES_PER_VIDEO,
                num_classes=NUM_CLASSES, 
                learning_rate=LEARNING_RATE,
                weights = data.weights)

    logger = CSVLogger("logs", name=f"resnet", flush_logs_every_n_steps=1)
    trainer = L.Trainer(
        #deterministic=True,
        devices=DEVICES,
        max_time=MAX_TIME,
        callbacks=[EarlyStopping(monitor="val_loss", patience=5)],
        default_root_dir=CHECKPOINT_DIR,
        fast_dev_run=FAST_DEV_RUN,
        logger=logger,
        accelerator=ACCELERATOR
    )
    trainer.fit(model, data)

class_labels:  ['BackSpace', 'Comma', 'Space', 'Stop', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
class_weights:  [0.3786203010544957, 3.4687315634218288, 0.23390521656969515, 5.157456140350877, 0.535290770456356, 2.7603286384976524, 1.2979028697571744, 1.3481226712525078, 0.35877955758962626, 2.4195473251028807, 2.2689821514712976, 1.350445018662073, 0.5581583007001305, 4.544541062801932, 5.332879818594105, 0.9848408710217755, 1.8958484482063684, 0.6643502824858757, 0.6276487856952229, 1.7267254038179147, 5.619593787335723, 0.6949763593380615, 0.8534930139720559, 0.5532345330510469, 1.1135416666666667, 4.180977777777778, 2.9921119592875316, 3.637741686001547, 2.1389722601182357, 4.040893470790378]


  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
INFO: Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
INFO: ----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

INFO: LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO: 
  | Name     | Type               | Params | Mode 
--------------------------------------------------------
0 | model    | ResNet             | 82.5 M | train
1 | loss_fn  | CrossEntropyLoss   | 0      | train
2 | accuracy | MulticlassAccuracy | 0      | train
------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Val dataset:Val dataset:  74837483



/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Train dataset: 47036
Train dataset: 47036


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('train_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('train_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]