In [1]:
!pip install torch lightning torchvision pyav

Collecting lightning
  Downloading lightning-2.3.1-py3-none-any.whl.metadata (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.1/54.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting pyav
  Downloading pyav-12.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Downloading lightning-2.3.1-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m36.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyav-12.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.2/30.2 MB[0m [31m53.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyav, lightning
Successfully installed lightning-2.3.1 pyav-12.1.0


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial
import pathlib
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_video
import lightning as L
from lightning.pytorch.loggers import CSVLogger
import torchmetrics
from lightning.pytorch.callbacks import EarlyStopping
from sklearn.metrics import classification_report, confusion_matrix, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from resnet import resnet101
import numpy as np
LOCAL = True

# Training hyperparameters
IMG_SIZE = 320
FRAMES_PER_VIDEO = 8
NUM_CLASSES = 30
LEARNING_RATE = 0.001
BATCH_SIZE = 16
MAX_EPOCHS = 10000
MAX_TIME = "00:11:00:00"

# Datasets
LOCAL_DATA_DIR = f"./datasets/key_clf_data_{IMG_SIZE}_{IMG_SIZE}"
KAGGLE_DATA_DIR = f"/kaggle/input/key-clf/key_clf_data_{IMG_SIZE}_{IMG_SIZE}/key_clf_data_{IMG_SIZE}_{IMG_SIZE}"

NUM_WORKERS = 4

FAST_DEV_RUN = False
CHECKPOINT_DIR = "resnet/"

# Compute related
ACCELERATOR = "gpu"
DEVICES = [0,1]

id2Label = ['BackSpace', 'Comma', 'Space', 'Stop', 
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 
            'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 
            'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
            'y', 'z']

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
####### 
class KeyClf(L.LightningModule):
    def __init__(self, img_size, frames_per_video, num_classes, learning_rate, weights):
        super().__init__()
        self.model = resnet101(
            sample_size=img_size, 
            sample_duration=frames_per_video,
            shortcut_type='B', 
            num_classes=num_classes
        )
        
        self.loss_fn = torch.nn.CrossEntropyLoss(torch.tensor(weights))
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.lr = learning_rate
        self.test_y = []
        self.test_pred = []
        self.save_hyperparameters()

    def test_step(self, batch):
        videos, targets = batch
        preds = self.model(videos)
        self.test_pred.append(preds)
        self.test_y.append(targets)
        print("\ntarget", targets)
        print("preds", torch.argmax(preds, 1))

        loss = self.loss_fn(preds, targets.long())
        test_acc = self.accuracy(preds, targets)
        self.log_dict({'test_acc': test_acc, 'test_loss': loss})


    def on_test_end(self) -> None:
        preds = torch.cat(self.test_pred)
        targets = torch.cat(self.test_y)
        acc = self.accuracy(preds, targets)
        print('acc: ', acc)
#         print("target", targets[:5])
#         print("preds", torch.argmax(preds[:5], 1))
        targets = targets.cpu().numpy()
        preds = torch.argmax(preds, 1).cpu().numpy()
        print(classification_report(targets, preds))
        cm = confusion_matrix(targets, preds)
        plt.figure(figsize=(16, 12))
        sns.heatmap(cm, annot=True, cmap='Blues', fmt='d', xticklabels=id2Label, yticklabels=id2Label)
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        plt.savefig('result.png')

    
    def training_step(self, batch):
        videos, targets = batch
        preds = self.model(videos)
        loss = self.loss_fn(preds, targets.long())
        self.log_dict({ "train_loss": loss, "train_acc": self.accuracy(preds, targets)}, 
                      on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch):
        videos, targets = batch
        preds = self.model(videos)
        loss = self.loss_fn(preds, targets.long())
        self.log_dict({ "val_loss": loss, "val_acc": self.accuracy(preds, targets)}, 
                      on_step=False, on_epoch=True, prog_bar=True,)
        return loss
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
    


In [12]:
from key_utils import KeyDataset
dataset = KeyDataset(video_name='video_36', 
                     videos_dir='./datasets/raw_frames_320',
                     labels_dir='./datasets/labels',
                     f_after=4,
                     f_before=3,
                     gap=2,
                     total_window=8,
                     color_channel_last=False)

data_loader = DataLoader(dataset,
                          batch_size=4,
                          num_workers=0,
                          shuffle=False)

In [6]:
class KeyClsDataset(Dataset):
    def __init__(self, data_dir, mode):
        self.dataset_root_path = pathlib.Path(data_dir)
        self.all_video_file_paths =  list(self.dataset_root_path.glob(f"{mode}/*/*.mp4"))
        self.class_labels = sorted({str(path).split("/")[-2] for path in self.all_video_file_paths})
        self.label2id = {label: i for i, label in enumerate(self.class_labels)}
        self.id2label  = {i: label for label, i in self.label2id.items()}
  
    def __len__(self):
        return len(self.all_video_file_paths)

    def __getitem__(self, idx):
        file_path = self.all_video_file_paths[idx]
        vframes, _, _ = read_video(file_path, pts_unit='sec')
        label = str(file_path).split("/")[-2]
        # permute to (num_frames, num_channels, height, width)
        vframes = vframes.permute(3, 0, 1, 2).float() / 255.0
        return vframes, self.label2id[label]

In [8]:
dataset = KeyClsDataset(data_dir='datasets/key_clf_data_320_320', mode='test')
data_loader = DataLoader(dataset,
                          batch_size=4,
                          num_workers=0,
                          shuffle=False)

print(dataset[0][0].shape)

torch.Size([3, 8, 320, 320])


In [14]:
weights

array([0.3786203 , 3.46873156, 0.23390522, 5.15745614, 0.53529077,
       2.76032864, 1.29790287, 1.34812267, 0.35877956, 2.41954733,
       2.26898215, 1.35044502, 0.5581583 , 4.54454106, 5.33287982,
       0.98484087, 1.89584845, 0.66435028, 0.62764879, 1.7267254 ,
       5.61959379, 0.69497636, 0.85349301, 0.55323453, 1.11354167,
       4.18097778, 2.99211196, 3.63774169, 2.13897226, 4.04089347])

In [13]:
weights = np.load('key_cls_weights.npy')

model = KeyClf(
    img_size=IMG_SIZE,
    frames_per_video=FRAMES_PER_VIDEO,
    num_classes=NUM_CLASSES,
    learning_rate=LEARNING_RATE,
    weights = [1.0 for i in range(30)]
)

trainer = L.Trainer(accelerator="cpu")

# logger = CSVLogger("logs", name=f"resnet_101_img_{IMG_SIZE}", flush_logs_every_n_steps=1)

# if LOCAL:
#     trainer = L.Trainer(
#                 max_time=MAX_TIME,
#                 callbacks=[EarlyStopping(monitor="val_loss", patience=10)],
#                 fast_dev_run=True,
#                 logger=logger,
#                 accelerator="cpu"
#             )
# else: 
#     trainer = L.Trainer(
#             devices=DEVICES,
#             max_time=MAX_TIME,
#             callbacks=[EarlyStopping(monitor="val_loss", patience=10)],
#             fast_dev_run=FAST_DEV_RUN,
#             logger=logger,
#             accelerator=ACCELERATOR
#         )

# # trainer.fit(model, dm)
trainer.test(model, data_loader, ckpt_path="ckpts/resnet_320_jun_39/epoch=11-step=16551_320.ckpt")

  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Restoring states from the checkpoint path at ckpts/resnet_320_jun_39/epoch=11-step=16551_320.ckpt
/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.3.1, which is newer than your current Lightning version: v2.3.0
Loaded model weights from the checkpoint at ckpts/resnet_320_jun_39/epoch=11-step=16551_320.ckpt
/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consid

Testing DataLoader 0:   0%|          | 0/293 [00:00<?, ?it/s]
target tensor([ 7,  8,  4, 21])
preds tensor([29, 29, 29, 29])
Testing DataLoader 0:   0%|          | 1/293 [00:03<17:09,  0.28it/s]
target tensor([ 2, 23,  8,  4])
preds tensor([29, 29, 29, 29])
Testing DataLoader 0:   1%|          | 2/293 [00:06<15:37,  0.31it/s]
target tensor([16,  1,  2, 12])
preds tensor([29, 29, 29, 29])
Testing DataLoader 0:   1%|          | 3/293 [00:09<14:36,  0.33it/s]
target tensor([ 2, 11, 18, 18])
preds tensor([29, 29, 29, 29])
Testing DataLoader 0:   1%|▏         | 4/293 [00:11<13:23,  0.36it/s]
target tensor([19,  0,  0, 19])
preds tensor([29, 29, 29, 29])
Testing DataLoader 0:   2%|▏         | 5/293 [00:13<12:43,  0.38it/s]
target tensor([ 8,  2, 23, 11])
preds tensor([29, 29, 29, 29])
Testing DataLoader 0:   2%|▏         | 6/293 [00:15<12:13,  0.39it/s]
target tensor([12, 22,  2,  8])
preds tensor([29, 29, 29, 29])
Testing DataLoader 0:   2%|▏         | 7/293 [00:17<11:52,  0.40it/s]
target 

/Users/haily/.pyenv/versions/3.10.4/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
