In [1]:
import lightning as L

In [2]:
# Frame rate used when reasoning about clip length (frames per second).
FPS = 3

# Frame dimensions — presumably height, width and channel count; these are
# not referenced by the cells visible in this notebook (TODO confirm usage).
H = 512
W = 512
C = 3

# Frames per clip: 10 s of video at FPS = 3.
NUM_FRAMES = 30

In [3]:
import os
import random
import math
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

# Length of one chunk, in seconds and in frames. The frame count is derived
# from the notebook-level FPS constant instead of repeating the literal 3.
chunk_duration_s = 10
chunk_duration_frames = FPS * chunk_duration_s

# Directory holding the extracted video frames, one sub-folder per class
# (the layout ImageFolder expects).
frames_directory = 'frames/labeled_data/'

# Every image is resized to 224x224 and converted to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
    ])

# Create the ImageFolder dataset
dataset = ImageFolder(root=frames_directory, transform=transform)


def _frames_for_chunks(chunk_ids, frames_per_chunk, num_frames):
    """Expand chunk indices into the dataset indices of the frames they cover.

    Args:
        chunk_ids: Iterable of chunk indices. Chunk ``c`` covers the frame
            indices ``[c * frames_per_chunk, (c + 1) * frames_per_chunk)``.
        frames_per_chunk: Number of frames in a full chunk.
        num_frames: Total number of frames; indices at or beyond this value
            are dropped, so the last chunk may be shorter than the others.

    Returns:
        A list of frame indices, in the order the chunks were given.
    """
    return [
        frame_idx
        for chunk_idx in chunk_ids
        for frame_idx in range(
            chunk_idx * frames_per_chunk,
            min((chunk_idx + 1) * frames_per_chunk, num_frames),
        )
    ]


# Group consecutive dataset indices into fixed-size chunks and split by chunk
# rather than by individual frame — presumably so that neighbouring frames do
# not land on both sides of the split.
# NOTE(review): this assumes consecutive ImageFolder indices are temporally
# adjacent frames of the same clip — confirm against the directory layout.
total_frames = len(dataset)
total_chunks = math.ceil(total_frames / chunk_duration_frames)

chunk_indices = list(range(total_chunks))

# Seeded shuffle, then 80% of the chunks (rounded down) go to training and
# the remainder to validation.
train_size = int(0.8 * total_chunks)
random.seed(42)
random.shuffle(chunk_indices)
train_chunk_indices = chunk_indices[:train_size]
val_chunk_indices = chunk_indices[train_size:]

# Translate the chunk-level split into frame-level dataset indices.
train_indices = _frames_for_chunks(train_chunk_indices, chunk_duration_frames, total_frames)
val_indices = _frames_for_chunks(val_chunk_indices, chunk_duration_frames, total_frames)

# Invariants: the two index sets partition the dataset without overlap.
assert len(train_indices) + len(val_indices) == total_frames, "split does not cover every frame"
assert set(train_indices).isdisjoint(val_indices), "train and validation frames overlap"

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# Create the data loaders for training and validation; only the training
# loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=32, num_workers=4)

In [4]:
# Preview the 15 smallest frame indices on each side of the split.
tuple(sorted(indices)[:15] for indices in (train_indices, val_indices))

([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44])

In [8]:
from typing import Any, Optional
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
import torch.nn as nn
import torchvision.models as models
import pytorch_lightning as pl
import pytorch_lightning.loggers as loggers
import torchmetrics

class FineTuneResNet(pl.LightningModule):
    """ResNet-18 image classifier with a replaced final layer.

    The torchvision ResNet-18 is loaded with ImageNet weights and its fully
    connected head is swapped for a fresh ``nn.Linear`` with ``num_classes``
    outputs. No layers are frozen, so the optimizer updates every parameter.

    Args:
        num_classes: Number of output classes; also passed to the accuracy
            metrics.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, num_classes, learning_rate=1e-3):
        super().__init__()
        # `pretrained=True` is deprecated since torchvision 0.13; the explicit
        # weights enum selects the same ImageNet checkpoint without the
        # deprecation warning.
        self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.learning_rate = learning_rate
        # Separate metric objects so training and validation state never mix.
        self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped ResNet's output (class logits) for ``x``."""
        return self.model(x)

    def _logits_and_loss(self, batch):
        """Run the model on an ``(inputs, targets)`` batch.

        Returns:
            Tuple ``(logits, targets, loss)`` where ``loss`` is the
            cross-entropy between ``logits`` and ``targets``.
        """
        x, y = batch
        logits = self(x)
        # Functional form: same computation as nn.CrossEntropyLoss() with
        # default arguments, without building a new module on every step.
        loss = nn.functional.cross_entropy(logits, y)
        return logits, y, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy; return the loss."""
        logits, y, loss = self._logits_and_loss(batch)
        self.log('train_loss', loss)
        self.accuracy(logits, y)
        self.log('train_acc_step', self.accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        logits, y, loss = self._logits_and_loss(batch)
        self.log('val_loss', loss)
        self.val_accuracy(logits, y)
        self.log('val_acc_step', self.val_accuracy)

    def on_validation_batch_end(self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Log the validation accuracy metric under ``valid_acc_epoch``.

        NOTE(review): the same metric object is already logged as
        ``val_acc_step`` in ``validation_step``, so this key looks redundant —
        confirm nothing monitors it before removing.
        """
        self.log("valid_acc_epoch", self.val_accuracy)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

# Build the classifier and a trainer that stops after a single epoch.
NUM_CLASSES = 2
MAX_EPOCHS = 1

model = FineTuneResNet(num_classes=NUM_CLASSES)
trainer = pl.Trainer(max_epochs=MAX_EPOCHS)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# Train on the training loader, validating against the validation loader.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)


  | Name         | Type               | Params
----------------------------------------------------
0 | model        | ResNet             | 11.2 M
1 | accuracy     | MulticlassAccuracy | 0     
2 | val_accuracy | MulticlassAccuracy | 0     
----------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  tp = tp.sum(dim=0 if multidim_average == "global" else 1)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [None]:
# Persist the trainer's current state to a checkpoint file in the working directory.
checkpoint_path = "lightning.ckpt"
trainer.save_checkpoint(checkpoint_path)

In [None]:
# ~30 min — about as fast as fast.ai, really...

In [None]:
# ~10 min for one epoch! :O