<a href="https://colab.research.google.com/github/madelyn-redick/LearningASL/blob/cnn/CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import ResNet50_Weights

In [None]:
class ASL_CNN(nn.Module):
  '''
  Convolutional Neural Network for ASL sign recognition.
  Extends pytorch's nn.Module.

  Architecture:
    - Transfer learning from resnet50 (ImageNet IMAGENET1K_V2 weights by default)
    - Freeze all layers except the last residual block (layer4)
    - Replace fully connected layer with
      - Linear -> ReLU -> Dropout -> Linear
    - Output has `num_classes` logits (default 24: the 24 static letter signs)
  '''

  def __init__(self, num_classes=24, dropout=0.4, weights=ResNet50_Weights.IMAGENET1K_V2):
    '''
    Initializes the CNN model.

    Parameters
    num_classes : int
      Number of output classes (default 24, the static ASL letters).
    dropout : float
      Dropout probability used in the classification head (default 0.4).
    weights : ResNet50_Weights or None
      Pretrained weights for the resnet50 backbone; None builds a randomly
      initialised backbone (useful for offline construction).
    '''
    super().__init__()

    # The attribute name `_base_model` is kept so saved state_dict keys still load.
    self._base_model = models.resnet50(weights=weights)

    # Freeze the entire backbone ...
    for param in self._base_model.parameters():
      param.requires_grad = False

    # ... then unfreeze only the last residual block so it can be fine-tuned.
    for param in self._base_model.layer4.parameters():
      param.requires_grad = True

    num_features = self._base_model.fc.in_features  # 2048 for resnet50
    # The new head is built after freezing, so its parameters keep requires_grad=True.
    self._base_model.fc = nn.Sequential(
        nn.Linear(num_features, 512),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(512, num_classes),
    )

  def forward(self, x):
    '''
    Executes a forward pass of the CNN model.

    Parameters
    x : Tensor
      The input image batch (presumably shaped (batch, 3, H, W) as resnet50
      expects; confirm against the data pipeline).

    Returns a tensor of raw (un-normalised) class scores, one row per input;
    the head ends in a Linear layer, so no softmax is applied here.
    '''
    return self._base_model(x)

In [None]:
def train_model(model, loss_fn, optimizer, scheduler, num_epochs=20, *,
                dataloaders=None, target_device=None, params_path=None,
                verbose=True):
  '''
  Function to train a model.

  Each epoch runs a 'train' phase (gradients enabled, optimizer stepped, then
  the scheduler stepped once) followed by a 'validation' phase (gradients
  disabled). Whenever validation accuracy beats the best seen so far, the
  model's state_dict is checkpointed to `params_path`. The best checkpoint is
  loaded back into the model before returning.

  Parameters
  model : nn.Module
    A pytorch neural network model. It is moved to the target device in place.
  loss_fn: (Tensor, Tensor) => Tensor (scalar)
    A criterion function to calculate loss given the predictions and labels.
  optimizer: torch.optim.Optimizer
    A pytorch optimizer.
  scheduler: torch.optim.lr_scheduler
    A pytorch learning rate scheduler, stepped once per epoch after training.
  num_epochs: int
    The number of epochs to train the model for.
  dataloaders: dict, optional
    Maps 'train' and 'validation' to iterables of (inputs, labels) batches.
    Defaults to the notebook-level `dataloader` global.
  target_device: torch.device, optional
    Device to train on. Defaults to the notebook-level `device` global.
  params_path: str, optional
    File used to checkpoint the best weights. Defaults to the notebook-level
    `best_model_params_path` global.
  verbose: bool
    If True, print loss and accuracy for every phase of every epoch.

  Returns the model with weights updated from the best epoch run.

  Raises ValueError if a phase's dataloader yields no samples.
  '''
  # Fall back to the notebook globals the original version of this function used.
  if dataloaders is None:
    dataloaders = dataloader
  if target_device is None:
    target_device = device
  if params_path is None:
    params_path = best_model_params_path

  # Every batch is sent to target_device, so the weights must live there too.
  # Module.to() updates parameters in place, so an optimizer built earlier
  # still tracks the same parameter objects.
  model.to(target_device)

  # Save the starting weights so there is always a checkpoint to reload.
  torch.save(model.state_dict(), params_path)
  best_accuracy = 0.0

  for epoch in range(num_epochs):
    for phase in ('train', 'validation'):
      is_training = phase == 'train'
      model.train(is_training)  # train() for 'train', eval() for 'validation'

      cumulative_loss = 0.0
      cumulative_corrects = 0
      num_samples = 0

      for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(target_device)
        labels = labels.to(target_device)

        if is_training:
          optimizer.zero_grad()

        with torch.set_grad_enabled(is_training):
          outputs = model(inputs)
          predictions = outputs.argmax(dim=1)
          loss = loss_fn(outputs, labels)

          if is_training:
            loss.backward()
            optimizer.step()

        batch_size = inputs.size(0)
        # loss_fn is assumed to average over the batch (CrossEntropyLoss's
        # default), so scale back up to a per-sample sum.
        cumulative_loss += loss.item() * batch_size
        cumulative_corrects += (predictions == labels).sum().item()
        num_samples += batch_size

      if num_samples == 0:
        raise ValueError(f"The '{phase}' dataloader yielded no samples.")

      if is_training:
        scheduler.step()

      # Average over the samples actually seen (equals the dataset size unless
      # the loader drops incomplete batches).
      epoch_loss = cumulative_loss / num_samples
      epoch_accuracy = cumulative_corrects / num_samples

      if verbose:
        print(f"Epoch {epoch + 1}/{num_epochs} {phase}: "
              f"loss={epoch_loss:.4f} accuracy={epoch_accuracy:.4f}")

      if not is_training and epoch_accuracy > best_accuracy:
        best_accuracy = epoch_accuracy
        torch.save(model.state_dict(), params_path)

  model.load_state_dict(
      torch.load(params_path, map_location=target_device, weights_only=True))
  return model

In [None]:
# GPU/CPU and Path Definition
best_model_params_path = './best_cnn_params.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model Instance and Hyperparameters
# Move the model to `device` before building the optimizer: the training loop
# sends every batch to `device`, so the weights must live there as well.
model = ASL_CNN().to(device)
# optimize only the parameters that are not frozen (AKA requires_grad == True)
# small learning rate because we are using transfer learning and do not want to mess up pretrained weights
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, momentum=0.9)
# multiply the learning rate by 0.9 each time the scheduler is stepped
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
loss = nn.CrossEntropyLoss()

# NOTE(review): training relies on data-loading globals (e.g. `dataloader`)
# that are not created in this cell; confirm an earlier cell defines them.
model = train_model(model=model, loss_fn=loss, optimizer=optimizer, scheduler=scheduler, num_epochs=30)