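Trains a voting ensemble of `CNN` classifiers on the WLASL video dataset with torchensemble's `VotingClassifier` (parameter reference: https://ensemble-pytorch.readthedocs.io/en/latest/parameters.html#voting).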

In [1]:
import random
import sys

import numpy as np
import torch
from torch import nn
from torchensemble import VotingClassifier

# Make the parent directory importable so the local `nets` package resolves
# (assumes the notebook is run from one level below the project root).
sys.path.insert(1, "../")

from nets.SimpleNet.lib.simple_model import CNN
from nets.SimpleNet.config.torch_config import get_transform
from nets.SimpleNet.utils.video_dataset import VideoFrameDataset
from nets.SimpleNet.config.dataset import get_dataset_path
from nets.SimpleNet.utils.loader import split_dataset


In [2]:
# Clip sampling: each clip yields NUM_SEGMENTS * FRAMES_PER_SEGMENT frames.
NUM_SEGMENTS = 10
FRAMES_PER_SEGMENT = 5
BATCH_SIZE = 16
# Input frame size and random-crop factor passed to get_transform; the model's
# image_size is derived as int(IMAGE_SIZE * IMAGE_RANDOM_CROP_RESIZE).
IMAGE_SIZE = 64
IMAGE_RANDOM_CROP_RESIZE = 0.8
NUM_EPOCHS = 5
LEARNING_RATE = 1e-3
DEBUG = False  # forwarded to every CNN estimator
SEED = 42

# Seed the Python, NumPy and PyTorch global RNGs so that random splits,
# augmentations and weight initialisation repeat across Restart & Run All
# (assumes the project helpers draw from these global RNGs — confirm).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)



In [3]:
# Resolve where the WLASL clips live and where the model artefacts belong.
# NOTE(review): model_path is not used by any cell shown here; presumably it
# was meant as the checkpoint directory for ensemble.fit(save_dir=...) — confirm.
data_path, model_path = get_dataset_path(
    model_name="WLASL",
    dataset="WLASL/videos",
)

In [4]:
# Frame transform parameterised by the output size and crop factor from the
# config cell; handed to VideoFrameDataset in the next cell.
transform_args = (IMAGE_SIZE, IMAGE_RANDOM_CROP_RESIZE)
multiple_transform = get_transform(*transform_args)

In [5]:
# Video dataset: every sample is NUM_SEGMENTS segments of FRAMES_PER_SEGMENT
# frames, each passed through multiple_transform.
dataset = VideoFrameDataset(
    root_path=data_path,
    transform=multiple_transform,
    num_segments=NUM_SEGMENTS,
    frames_per_segment=FRAMES_PER_SEGMENT,
)

classes = dataset.classes
# Fail fast: an empty class list would otherwise only surface later, when the
# ensemble is built with num_classes=len(classes) == 0.
assert len(classes) > 0, f"No classes found under {data_path!r}; check get_dataset_path()"
print(classes)


['before', 'book', 'candy', 'chair', 'clothes']


In [6]:
# Fractions of `dataset` used for training and validation; the remainder
# presumably becomes the held-out test split — confirm in split_dataset.
TRAIN_SPLIT = 0.70
VALIDATION_SPLIT = 0.1

train_loader, test_loader, validation_loader = split_dataset(
    dataset,
    batch_size=BATCH_SIZE,
    train_split=TRAIN_SPLIT,
    validation_split=VALIDATION_SPLIT,
)


In [7]:
# N_ESTIMATORS independent CNNs combined by torchensemble's VotingClassifier;
# estimator_args holds the keyword arguments used to instantiate each CNN.
N_ESTIMATORS = 10

ensemble = VotingClassifier(
    estimator=CNN,
    n_estimators=N_ESTIMATORS,
    # A hard-coded cuda=True fails on CPU-only machines; use the GPU only if visible.
    cuda=torch.cuda.is_available(),
    estimator_args={
        "num_classes": len(classes),
        # NOTE(review): if CNN reshapes using this fixed batch size, a final
        # partial batch from the loaders would break it — confirm.
        "batch_size": BATCH_SIZE,
        "num_frames": FRAMES_PER_SEGMENT * NUM_SEGMENTS,
        # Presumably the side length after get_transform's random crop — confirm.
        "image_size": int(IMAGE_SIZE * IMAGE_RANDOM_CROP_RESIZE),
        "debug": DEBUG,
    },
)


In [8]:
ensemble.set_criterion(nn.CrossEntropyLoss())
ensemble.set_optimizer("Adam", lr=LEARNING_RATE)
# Cosine-anneal the learning rate over the whole run. T_max counts scheduler
# steps, and torchensemble's fit() steps a non-plateau scheduler once per
# epoch, so T_max=NUM_EPOCHS finishes the decay on the final epoch.
ensemble.set_scheduler("CosineAnnealingLR", T_max=NUM_EPOCHS)

In [9]:
# The validation split is handed to torchensemble as its `test_loader`, so the
# held-out test split stays untouched until the final evaluation cell.
ensemble.fit(
    train_loader,
    epochs=NUM_EPOCHS,
    test_loader=validation_loader,
)


Estimator: 000 | Epoch: 000 | Batch: 000 | Loss: 1.61338 | Correct: 4/16
Estimator: 001 | Epoch: 000 | Batch: 000 | Loss: 1.60008 | Correct: 5/16
Estimator: 002 | Epoch: 000 | Batch: 000 | Loss: 1.62123 | Correct: 4/16
Estimator: 003 | Epoch: 000 | Batch: 000 | Loss: 1.62183 | Correct: 2/16
Estimator: 004 | Epoch: 000 | Batch: 000 | Loss: 1.70372 | Correct: 1/16
Estimator: 005 | Epoch: 000 | Batch: 000 | Loss: 1.69355 | Correct: 0/16
Estimator: 006 | Epoch: 000 | Batch: 000 | Loss: 1.62975 | Correct: 5/16
Estimator: 007 | Epoch: 000 | Batch: 000 | Loss: 1.67745 | Correct: 3/16
Estimator: 008 | Epoch: 000 | Batch: 000 | Loss: 1.67161 | Correct: 1/16
Estimator: 009 | Epoch: 000 | Batch: 000 | Loss: 1.55708 | Correct: 5/16
Estimator: 000 | Epoch: 001 | Batch: 000 | Loss: 1.55877 | Correct: 2/16
Estimator: 001 | Epoch: 001 | Batch: 000 | Loss: 1.56361 | Correct: 3/16
Estimator: 002 | Epoch: 001 | Batch: 000 | Loss: 1.60525 | Correct: 5/16
Estimator: 003 | Epoch: 001 | Batch: 000 | Loss: 1.

In [None]:
# torchensemble's evaluate() returns accuracy as a percentage
# (100 * correct / total), so print it with its unit rather than a bare number.
accuracy = ensemble.evaluate(test_loader)
print(f"Test accuracy: {accuracy:.2f}%")
