In [8]:
import lightning as L

In [9]:
FPS = 3
H, W, C = 512, 512, 3
NUM_FRAMES = 30 # 10s

In [1]:
from typing import Iterator, List

import torch
from torch.utils.data import Sampler
import numpy as np


class RandomSubsequenceSampler(Sampler[List[int]]):
    def __init__(self, num_samples: int, batch_size: int) -> None:
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.start_indices = np.random.permutation(len(self))

    def __iter__(self) -> Iterator[List[int]]:
        for index in self.start_indices:
            index *= self.batch_size
            yield torch.arange(index, min(index + self.batch_size, self.num_samples))

    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size

In [8]:
import os
import random
import math
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

# Define the duration of each chunk in seconds
chunk_duration_s = 10
chunk_duration_frames = 3 * chunk_duration_s

# Define the path to the video frames directory
frames_directory = 'frames/labeled_data/'

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
    ])

# Create the ImageFolder dataset
dataset = ImageFolder(root=frames_directory, transform=transform)

# Calculate the total number of chunks
total_frames = len(dataset)
total_chunks = math.ceil(total_frames / chunk_duration_frames)

# Create a list of chunk indices
chunk_indices = [i for i in range(total_chunks)]

# Split the chunk indices into training and validation sets
train_size = int(0.8 * total_chunks)
random.seed(42)
random.shuffle(chunk_indices)
train_chunk_indices = chunk_indices[:train_size]
val_chunk_indices = chunk_indices[train_size:]

# Create the training and validation subsets
train_indices = [frame_idx for chunk_idx in train_chunk_indices for frame_idx in range(chunk_idx * chunk_duration_s, (chunk_idx + 1) * chunk_duration_s) if frame_idx < total_frames]
val_indices = [frame_idx for chunk_idx in val_chunk_indices for frame_idx in range(chunk_idx * chunk_duration_s, (chunk_idx + 1) * chunk_duration_s) if frame_idx < total_frames]

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# Create the data loaders for training and validation
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=32, num_workers=4)

In [9]:
import torch
import torch.nn as nn
import torchvision.models as models
import pytorch_lightning as pl

class FineTuneResNet(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-3):
        super().__init__()
        self.resnet = models.resnet34(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.resnet(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.CrossEntropyLoss()(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.CrossEntropyLoss()(logits, y)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

# Instantiate the LightningModule and Trainer
model = FineTuneResNet(num_classes=2)
trainer = pl.Trainer(max_epochs=1)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [11]:
trainer.fit(model, train_dataloader, val_dataloader)


  | Name   | Type   | Params
----------------------------------
0 | resnet | ResNet | 21.3 M
----------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.143    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

In [None]:
# ~10m on one Epoch! :O