### Import dependencies

In [1]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader

### Define the device to use in training

In [2]:
# Pick the accelerator once for the whole notebook: the GPU when CUDA is
# available, otherwise the CPU. The bare name on the last line displays it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

### Setup Dataset

Train Dataset

In [3]:
# NOTE(review): `datasets` is presumably a local module (datasets.py next to
# this notebook); its name collides with the Hugging Face `datasets` package,
# so import resolution depends on sys.path order — consider renaming it.
from datasets import ASLTrainDataset

# Training-data configuration: tune these here rather than inline.
TRAIN_DIR = "./dataset/asl_alphabet_train"
BATCH_SIZE = 128
SEED = 42

train_dataset = ASLTrainDataset(dir=TRAIN_DIR)
# A dedicated, seeded generator makes the shuffle order reproducible across
# Restart & Run All without reseeding the global RNG used elsewhere.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

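Test Dataset

**TODO:** no test/validation split is loaded yet — add a held-out evaluation set before reporting model quality.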

### Initialize ASL Model

In [4]:
from model import ASLModel

# Construct the network, then move its parameters to the selected device.
# NOTE(review): confirm whether `out_channels=64` is a hidden feature width or
# the number of output logits — the training loop scores against 29 classes.
model = ASLModel(in_channels=3, out_channels=64)
model = model.to(device)

### Define the loss function and optimizer

In [5]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

In [6]:
# Optimisation hyper-parameters (same values as before, just named).
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# Classification loss plus Adam over every trainable parameter of `model`.
criterion = CrossEntropyLoss()
optimizer = Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

### Train the model

Define the training function

In [7]:
from torcheval.metrics.functional import multiclass_f1_score

In [8]:
def train(epoch: int, num_classes: int = 29, logs_per_epoch: int = 20) -> None:
    """Train the notebook-level ``model`` on ``train_loader`` for ``epoch`` epochs.

    Relies on ``model``, ``train_loader``, ``criterion``, ``optimizer`` and
    ``device`` defined in earlier cells, and prints a progress line
    (loss and micro-averaged F1 on the current batch) about
    ``logs_per_epoch`` times per epoch.

    Args:
        epoch: Number of full passes over ``train_loader``.
        num_classes: Number of target classes passed to the F1 metric
            (default 29 preserves the original hard-coded value).
        logs_per_epoch: Approximate number of progress lines per epoch.
    """
    # Make sure dropout / batch-norm layers are in training mode even if some
    # other cell previously called model.eval().
    model.train()

    iteration_count = len(train_loader)
    # Integer logging interval. The former float test
    # `i % (iteration_count / 20) == 0` only matched i == 0 when the batch
    # count was not a multiple of 20 (e.g. 681), and matched every step when
    # there were fewer than 20 batches.
    log_every = max(1, iteration_count // max(1, logs_per_epoch))

    for e in range(epoch):
        for i, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)
            out = model(x)

            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % log_every == 0:
                # The metric is only needed for logging: compute it on demand,
                # outside the autograd graph.
                with torch.no_grad():
                    score = multiclass_f1_score(out.detach(), y, num_classes=num_classes)
                print(f"epoch {e} | iteration: {i}/{iteration_count} | loss: {loss.item()} | f1-score: {score}")

In [9]:
# Release cached, currently unused GPU memory held by PyTorch's caching
# allocator before starting the training run.
torch.cuda.empty_cache()

NUM_EPOCHS = 3
train(epoch=NUM_EPOCHS)

epoch 0 | iteration: 0/680 | loss: 3.3659255504608154 | f1-score: 0.0234375
epoch 0 | iteration: 34/680 | loss: 3.377070426940918 | f1-score: 0.0078125
epoch 0 | iteration: 68/680 | loss: 3.249237298965454 | f1-score: 0.078125
epoch 0 | iteration: 102/680 | loss: 3.013221502304077 | f1-score: 0.15625
epoch 0 | iteration: 136/680 | loss: 2.678464412689209 | f1-score: 0.171875
epoch 0 | iteration: 170/680 | loss: 2.4586472511291504 | f1-score: 0.2890625
epoch 0 | iteration: 204/680 | loss: 2.117218255996704 | f1-score: 0.3515625
epoch 0 | iteration: 238/680 | loss: 2.1067874431610107 | f1-score: 0.34375
epoch 0 | iteration: 272/680 | loss: 1.7907192707061768 | f1-score: 0.4140625
epoch 0 | iteration: 306/680 | loss: 1.8849259614944458 | f1-score: 0.421875
epoch 0 | iteration: 340/680 | loss: 1.645187258720398 | f1-score: 0.421875
epoch 0 | iteration: 374/680 | loss: 1.4462394714355469 | f1-score: 0.5078125
epoch 0 | iteration: 408/680 | loss: 1.298228144645691 | f1-score: 0.5546875
epoch

### Save the model

In [11]:
# Persist only the learned weights (state_dict), not the pickled module.
# Create the target directory first so a fresh checkout without ./models
# does not fail with FileNotFoundError.
MODEL_PATH = Path("./models/dev.pt")
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)