In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from torchvision import transforms, models
import numpy as np
import os

In [2]:

class AlphabetDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root_dir`` per class.

    Class indices are assigned in sorted directory-name order, and the files
    inside each class directory are enumerated in sorted order, so the
    ``(path, label)`` list in ``self.data`` is deterministic for a given
    directory tree.

    Attributes:
        root_dir (str): Directory that was scanned.
        data (list[tuple[str, int]]): ``(file_path, class_index)`` pairs.
        class_to_idx (dict[str, int]): Class directory name -> class index.
        transform (callable): Applied to each RGB PIL image in ``__getitem__``.
    """

    # Lower-case file extensions that are treated as images. Matching is
    # case-insensitive so files such as "A1.JPG" are not silently skipped.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Root directory containing class folders with images.
            transform (callable, optional): Callable applied to each loaded
                RGB PIL image. When omitted, the images are resized /
                center-cropped to 224 and normalized with the ImageNet
                mean and std.
        """
        self.root_dir = root_dir
        self.data = []
        self.class_to_idx = {}

        if transform is None:
            # Default preprocessing (same as ImageNet normalization)
            transform = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform

        # Load all image paths and labels
        for class_name in sorted(os.listdir(root_dir)):  # Ensure consistent class order
            class_path = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_path):  # Skip stray files next to the class folders
                continue

            # Directory names are unique, so the next free index is the label.
            label = len(self.class_to_idx)
            self.class_to_idx[class_name] = label

            # os.listdir returns entries in arbitrary order; sort so that the
            # sample order does not depend on the filesystem.
            for filename in sorted(os.listdir(class_path)):
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.data.append((os.path.join(class_path, filename), label))

    def __len__(self):
        """Return the number of ``(file_path, label)`` samples found."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index into ``self.data``.

        Returns:
            image: Result of ``self.transform`` on the RGB image (a
                ``torch.Tensor`` with the default transform, no batch dimension).
            label (int): Corresponding class index.
        """
        file_path, label = self.data[idx]

        # convert() forces the pixel data to be read, so the underlying file
        # can be closed as soon as the RGB copy exists.
        with Image.open(file_path) as raw_img:
            img = raw_img.convert("RGB")  # Ensure it's RGB

        return self.transform(img), label


In [46]:

# Dataset location and loader settings for this notebook.
root_directory = "/workspace/dataset/asl_alphabet_train/asl_alphabet_train"
BATCH_SIZE = 256
SPLIT_SEED = 42  # fixes which samples land in train vs. validation

dataset = AlphabetDataset(root_directory)
# Loader over the full (unsplit) dataset; kept for any cell that still uses it.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# 80 / 20 train / validation split. Without an explicit generator the split
# changes on every kernel restart, which makes validation numbers from
# different runs incomparable. With it, the split is reproducible as long as
# the dataset enumerates its files in the same order.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)  # No shuffling for validation

In [47]:
import timm

class ImageModel(nn.Module):
    """Frozen EfficientNet-Lite0 backbone with a small trainable classifier head.

    The backbone is the pretrained timm ``efficientnet_lite0`` with its last
    four child modules removed, so it outputs a 320-channel feature map. A 1x1
    convolution reduces that to 24 channels, the map is flattened, and two
    linear layers produce the class logits.

    The first linear layer expects 1176 = 24 * 7 * 7 features, i.e. a 7x7
    feature map from the backbone. Inputs that yield another spatial size
    raise a shape error in ``self.linear``.
    """

    def __init__(self, n_class=28):
        """
        Args:
            n_class (int): Number of output classes (size of the logit vector).
        """
        super().__init__()

        efficientnet = timm.create_model('efficientnet_lite0', pretrained=True)
        # Keep everything except the last four children (the classification
        # end of the network); what remains ends with 320 output channels.
        self.feature_extractor = torch.nn.Sequential(*list(efficientnet.children())[:-4])
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        self.feature_extractor.eval()

        self.conv1x1 = nn.Conv2d(in_channels=320, out_channels=24, kernel_size=1, stride=1, padding=0, bias=False)

        self.linear = nn.Linear(1176, 128)  # 1176 = 24 channels * 7 * 7
        self.head = nn.Linear(128, n_class)
        self.relu = nn.ReLU()

    def train(self, mode=True):
        """Set train/eval mode on the head while keeping the backbone in eval.

        ``nn.Module.train`` recurses into every child, which would switch the
        frozen backbone's BatchNorm layers back to training mode and let their
        running statistics drift even though the weights are frozen. Forcing
        the backbone back to eval keeps the pretrained features fixed.

        Args:
            mode (bool): ``True`` for training mode, ``False`` for eval.

        Returns:
            ImageModel: ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.feature_extractor.eval()
        return self

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x (torch.Tensor): Image batch of shape (B, 3, H, W).

        Returns:
            torch.Tensor: Logits of shape (B, n_class).
        """
        # The backbone is frozen; skip building its autograd graph.
        with torch.no_grad():
            x = self.feature_extractor(x)  # (B, 320, h, w)

        x = self.conv1x1(x)  # (B, 24, h, w)
        x = self.relu(x)
        # flatten(1) keeps the batch dimension intact. reshape(-1, 1176) would
        # silently fold samples together if h * w were not 49.
        x = x.flatten(1)  # (B, 24 * h * w)
        x = self.relu(self.linear(x))  # (B, 128)
        x = self.head(x)  # (B, n_class)
        return x

    def num_params(self):
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [48]:

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the classifier with 29 output classes and move it to the chosen device.
model = ImageModel(n_class=29)
model.to(device)

# Report the device, the module tree and the trainable-parameter count.
print(device)
print(model)
print(f"Number of parameters: {model.num_params()}")


cuda
ImageModel(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNormAct2d(
      32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): ReLU6(inplace=True)
    )
    (2): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNormAct2d(
            32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): ReLU6(inplace=True)
          )
          (aa): Identity()
          (se): Identity()
          (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNormAct2d(
            16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): Identity()
     

In [49]:

# Training setup: number of passes over the training loader, the
# classification loss, and an Adam optimizer over the model's parameters.
epochs = 500
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
# Train for `epochs` passes over train_dataloader, evaluating on
# val_dataloader after every pass. Loss and accuracy are printed per batch
# and per epoch.
for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")

    # ===== Training Phase =====
    model.train()
    # model.train() also flips the frozen backbone to training mode, which
    # would update its BatchNorm running statistics; pin it back to eval.
    model.feature_extractor.eval()

    total_loss = 0.0
    correct = 0
    total = 0
    num_train_batches = len(train_dataloader)

    for batch_idx, (batch_data, batch_labels) in enumerate(train_dataloader):
        # The loader already yields tensors, so move them directly instead
        # of round-tripping every batch through nested Python lists.
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_data)

        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()

        # .item() forces a device sync; read the scalar once per batch.
        loss_value = loss.item()
        total_loss += loss_value

        # Compute Accuracy
        preds = torch.argmax(logits, dim=1)
        batch_correct = (preds == batch_labels).sum().item()
        batch_total = batch_labels.size(0)
        correct += batch_correct
        total += batch_total

        # Print loss and accuracy for this batch
        batch_acc = batch_correct / batch_total * 100  # Convert to percentage
        print(f"Batch {batch_idx+1}/{num_train_batches} - Loss: {loss_value:.4f} - Accuracy: {batch_acc:.2f}%")

    # Mean of the per-batch losses (not weighted by batch size).
    avg_train_loss = total_loss / num_train_batches
    train_accuracy = (correct / total) * 100
    print(f"Training - Avg Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")

    # ===== Validation Phase =====
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():  # Disable gradient computation
        for batch_data, batch_labels in val_dataloader:
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_data)
            val_loss += criterion(logits, batch_labels).item()

            # Compute Accuracy
            preds = torch.argmax(logits, dim=1)
            val_correct += (preds == batch_labels).sum().item()
            val_total += batch_labels.size(0)

    avg_val_loss = val_loss / len(val_dataloader)
    val_accuracy = (val_correct / val_total) * 100
    print(f"Validation - Avg Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%\n")

print("Training complete.")


Epoch 1/500
Batch 1/272 - Loss: 3.4071 - Accuracy: 3.12%
Batch 2/272 - Loss: 3.3604 - Accuracy: 8.98%
Batch 3/272 - Loss: 3.3480 - Accuracy: 4.30%
Batch 4/272 - Loss: 3.2701 - Accuracy: 8.20%
Batch 5/272 - Loss: 3.2427 - Accuracy: 11.33%
Batch 6/272 - Loss: 3.1237 - Accuracy: 17.97%
Batch 7/272 - Loss: 3.1023 - Accuracy: 14.84%
Batch 8/272 - Loss: 3.0337 - Accuracy: 16.41%
Batch 9/272 - Loss: 2.9498 - Accuracy: 20.31%
Batch 10/272 - Loss: 2.9310 - Accuracy: 20.31%
Batch 11/272 - Loss: 2.7912 - Accuracy: 30.08%
Batch 12/272 - Loss: 2.6559 - Accuracy: 37.89%
Batch 13/272 - Loss: 2.5506 - Accuracy: 37.89%
Batch 14/272 - Loss: 2.5120 - Accuracy: 34.77%
Batch 15/272 - Loss: 2.3924 - Accuracy: 39.45%
Batch 16/272 - Loss: 2.3047 - Accuracy: 39.45%
Batch 17/272 - Loss: 2.1494 - Accuracy: 44.53%
Batch 18/272 - Loss: 2.1159 - Accuracy: 44.14%
Batch 19/272 - Loss: 2.0062 - Accuracy: 49.61%
Batch 20/272 - Loss: 1.9289 - Accuracy: 52.73%
Batch 21/272 - Loss: 1.8183 - Accuracy: 51.17%
Batch 22/272 