In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np
import cv2
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class VideoFrameDataset(Dataset):
    def __init__(self, data_path, transform=None, num_frames=16):
        self.data_path = data_path
        self.transform = transform
        self.num_frames = num_frames
        self.classes = sorted([d for d in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, d))])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.samples = []

        # Collect video samples
        for class_name in self.classes:
            class_dir = os.path.join(self.data_path, class_name)
            for video_folder in os.listdir(class_dir):
                video_path = os.path.join(class_dir, video_folder)
                if os.path.isdir(video_path):
                    self.samples.append((video_path, self.class_to_idx[class_name]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        frame_files = sorted([f for f in os.listdir(video_path) if f.endswith('.jpg')])
        frames = []

        # Select num_frames frames, or less if there are fewer frames
        actual_num_frames = min(self.num_frames, len(frame_files))
        selected_indices = np.linspace(0, len(frame_files) - 1, num=actual_num_frames, dtype=int)

        for i in selected_indices:
            if i < len(frame_files):
                frame_path = os.path.join(video_path, frame_files[i])
                frame = cv2.imread(frame_path)
                if frame is None:
                    frame = np.zeros((224, 224, 3), dtype=np.uint8)
                else:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                if self.transform:
                    frame = self.transform(frame)
                frames.append(frame)

        while len(frames) < self.num_frames:
            frames.append(torch.zeros_like(frames[0]))

        frames = torch.stack(frames)
        return frames, label


In [4]:
class EarlyFusionCNN(nn.Module):
    def __init__(self, num_classes, dropout=0.5, num_frames=16):
        super(EarlyFusionCNN, self).__init__()

        # Pre-trained ResNet18
        self.resnet = models.resnet18(pretrained=True)
        self.resnet.fc = nn.Identity()  # Remove the fully connected layer

        # Number of frames
        self.num_frames = num_frames

        # MLP for classification
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)  # 512 is the output size from ResNet
        )

    def forward(self, x):
        batch_size, num_frames, c, h, w = x.shape

        # Flatten the frames into a 4D tensor
        x = x.view(batch_size * num_frames, c, h, w)  # (batch_size * num_frames, c, h, w)
        
        # Extract features from ResNet
        x = self.resnet(x)

        # Reshape to (batch_size, num_frames, 512)
        x = x.view(batch_size, num_frames, -1)
        
        # Combine features from frames (average over frames)
        x = x.mean(dim=1)  # Averaging over frames
        
        # Pass through MLP for classification
        x = self.fc(x)
        
        return x

In [5]:
def evaluate_model(model, test_loader, criterion, classes):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_loss = test_loss / total
    test_acc = 100 * correct / total
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%')

    # Confusion Matrix
    conf_mat = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(min(20, len(classes)), min(18, len(classes))))
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()

    # Classification Report
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=classes))

    return test_loss, test_acc

In [6]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=25, scheduler=None, early_stopping_patience=7):
    best_val_acc = 0.0
    patience_counter = 0
    
    # Lists to store training history
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    # Track the actual number of epochs completed due to early stopping
    actual_epochs = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_train_loss = running_loss / len(train_loader.dataset)
        epoch_train_acc = 100 * correct / total
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {epoch_train_acc:.2f}%")
        
        # Validation phase
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                running_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        epoch_val_loss = running_loss / len(val_loader.dataset)
        epoch_val_acc = 100 * correct / total
        val_losses.append(epoch_val_loss)
        val_accuracies.append(epoch_val_acc)
        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {epoch_val_acc:.2f}%")
        
        # Step the scheduler if applicable
        if scheduler is not None:
            scheduler.step(epoch_val_loss)
        
        # Early stopping logic
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), 'best_early_fusion_cnn_model.pth')
            print(f"  Saved best model with val accuracy: {best_val_acc:.2f}%")
        else:
            patience_counter += 1
            print(f"  Early stopping patience: {patience_counter}/{early_stopping_patience}")
            if patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Count the actual number of epochs run
        actual_epochs = epoch + 1

    # Plot training and validation loss/accuracy with dynamic epoch count
    plt.figure(figsize=(12, 6))

    # Loss plot
    plt.subplot(1, 2, 1)
    plt.plot(range(1, actual_epochs + 1), train_losses[:actual_epochs], label='Train Loss')
    plt.plot(range(1, actual_epochs + 1), val_losses[:actual_epochs], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')

    # Accuracy plot
    plt.subplot(1, 2, 2)
    plt.plot(range(1, actual_epochs + 1), train_accuracies[:actual_epochs], label='Train Accuracy')
    plt.plot(range(1, actual_epochs + 1), val_accuracies[:actual_epochs], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.title('Training and Validation Accuracy')

    # Save the plot
    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.close()

    return model


In [7]:
def main():
    # Data directories
    data_root = '/kaggle/input/msasl-process/processdata'
    train_data_path = os.path.join(data_root, 'processed_data_MS_ASL100_Train')
    val_data_path = os.path.join(data_root, 'processed_data_MS_ASL100_Val')
    test_data_path = os.path.join(data_root, 'processed_data_MS_ASL100_Test')
    
    # Training parameters
    batch_size = 32
    num_frames = 16
    num_epochs = 50
    learning_rate = 0.0001
    weight_decay = 1e-5
    
    # Data transforms
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomRotation(15),
        transforms.RandomCrop(224, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Load datasets
    print("Loading datasets...")
    train_dataset = VideoFrameDataset(train_data_path, transform=transform, num_frames=num_frames)
    val_dataset = VideoFrameDataset(val_data_path, transform=transform, num_frames=num_frames)
    test_dataset = VideoFrameDataset(test_data_path, transform=transform, num_frames=num_frames)
    
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    
    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    # Get classes
    classes = train_dataset.classes
    num_classes = len(classes)
    print(f"Number of classes: {num_classes}")
    
    # Initialize the model
    print("Initializing Early Fusion CNN model...")
    model = EarlyFusionCNN(num_classes=num_classes).to(device)
    
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1, verbose=True)
    
    # Train the model
    print("Starting training...")
    model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=num_epochs, scheduler=scheduler, early_stopping_patience=50)
    
    # Evaluate the model
    print("Evaluating model...")
    test_loss, test_acc = evaluate_model(model, test_loader, criterion, classes)
    
    # Save the final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'classes': classes,
        'test_acc': test_acc,
    }, 'early_fusion_cnn_final_model.pth')
    
    print(f"Final model saved with test accuracy: {test_acc:.2f}%")

In [8]:
if __name__ == "__main__":
    main()

Loading datasets...
Train samples: 3495
Val samples: 850
Test samples: 1335
Number of classes: 154
Initializing Early Fusion CNN model...


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 171MB/s]


Starting training...




Epoch 1/50, Train Loss: 5.1687, Train Accuracy: 1.40%
Epoch 1/50, Validation Loss: 4.8800, Validation Accuracy: 4.94%
  Saved best model with val accuracy: 4.94%
Epoch 2/50, Train Loss: 4.5864, Train Accuracy: 6.15%
Epoch 2/50, Validation Loss: 4.6179, Validation Accuracy: 8.47%
  Saved best model with val accuracy: 8.47%
Epoch 3/50, Train Loss: 4.1599, Train Accuracy: 13.30%
Epoch 3/50, Validation Loss: 4.3724, Validation Accuracy: 12.59%
  Saved best model with val accuracy: 12.59%
Epoch 4/50, Train Loss: 3.7207, Train Accuracy: 21.49%
Epoch 4/50, Validation Loss: 4.1523, Validation Accuracy: 15.06%
  Saved best model with val accuracy: 15.06%
Epoch 5/50, Train Loss: 3.3138, Train Accuracy: 30.76%
Epoch 5/50, Validation Loss: 3.9710, Validation Accuracy: 19.53%
  Saved best model with val accuracy: 19.53%
Epoch 6/50, Train Loss: 2.9708, Train Accuracy: 37.80%
Epoch 6/50, Validation Loss: 3.8138, Validation Accuracy: 21.06%
  Saved best model with val accuracy: 21.06%
Epoch 7/50, Trai

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
