In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
from torch.utils.data import Dataset  # Import Dataset
import os
import torch
import cv2
from torchvision import transforms
import numpy as np
class VideoFrameDataset(Dataset):
    def __init__(self, data_path, transform=None, num_frames=16):
        self.data_path = data_path
        self.transform = transform
        self.num_frames = num_frames
        self.classes = sorted([d for d in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, d))])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.samples = []
        
        # Lấy video và nhãn từ thư mục
        for class_name in self.classes:
            class_dir = os.path.join(self.data_path, class_name)
            for video_folder in os.listdir(class_dir):
                video_path = os.path.join(class_dir, video_folder)
                if os.path.isdir(video_path):
                    self.samples.append((video_path, self.class_to_idx[class_name]))

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        frame_files = sorted([f for f in os.listdir(video_path) if f.endswith('.jpg')])
        frames = []
        
        # Chọn num_frames frame, hoặc tất cả frame nếu ít hơn
        actual_num_frames = min(self.num_frames, len(frame_files))
        selected_indices = np.linspace(0, len(frame_files) - 1, num=actual_num_frames, dtype=int)
        
        for i in selected_indices:
            if i < len(frame_files):
                frame_path = os.path.join(video_path, frame_files[i])
                frame = cv2.imread(frame_path)
                if frame is None:
                    frame = np.zeros((224, 224, 3), dtype=np.uint8)
                else:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                if self.transform:
                    frame = self.transform(frame)
                frames.append(frame)
        
        while len(frames) < self.num_frames:
            frames.append(torch.zeros_like(frames[0]))
        
        frames = torch.stack(frames)
        return frames, label


In [4]:
# Mô hình CNN (Single-Frame CNN) với ResNet + MLP
class SingleFrameCNN(nn.Module):
    def __init__(self, num_classes, dropout=0.5):
        super(SingleFrameCNN, self).__init__()
        
        # Sử dụng ResNet18 đã được pre-trained
        self.resnet = models.resnet18(pretrained=True)
        self.resnet.fc = nn.Identity()  # Bỏ fully connected layer cuối
        
        # MLP để phân loại
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)  # 512 là kích thước đầu ra của ResNet
        )

    def forward(self, x):
        batch_size, num_frames, c, h, w = x.shape
        # Reshape để xử lý từng frame một
        x = x.view(batch_size * num_frames, c, h, w)
        # Trích xuất đặc trưng từ ResNet
        x = self.resnet(x)
        # Reshape lại để MLP xử lý
        x = x.view(batch_size, num_frames, -1)
        
        # Tính trung bình các logits
        x = x.mean(dim=1)  # Trung bình trên các frame
        
        # Đưa qua MLP để phân loại
        x = self.fc(x)
        
        return x

In [5]:
def evaluate_model(model, test_loader, criterion, classes):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_loss = test_loss / total
    test_acc = 100 * correct / total
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%')

    # Confusion Matrix
    conf_mat = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(min(20, len(classes)), min(18, len(classes))))
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()

    # Classification Report
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=classes))

    return test_loss, test_acc


In [6]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


In [7]:
import matplotlib.pyplot as plt

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=25, scheduler=None, early_stopping_patience=50):
    best_val_acc = 0.0
    patience_counter = 0
    
    # Lists to store training history
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    # Track the actual number of epochs completed due to early stopping
    actual_epochs = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_train_loss = running_loss / len(train_loader.dataset)
        epoch_train_acc = 100 * correct / total
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {epoch_train_acc:.2f}%")
        
        # Validation phase
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                running_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        epoch_val_loss = running_loss / len(val_loader.dataset)
        epoch_val_acc = 100 * correct / total
        val_losses.append(epoch_val_loss)
        val_accuracies.append(epoch_val_acc)
        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {epoch_val_acc:.2f}%")
        
        # Step the scheduler if applicable
        if scheduler is not None:
            scheduler.step(epoch_val_loss)
        
        # Early stopping logic
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), 'best_single_frame_cnn_model.pth')
            print(f"  Saved best model with val accuracy: {best_val_acc:.2f}%")
        else:
            patience_counter += 1
            print(f"  Early stopping patience: {patience_counter}/{early_stopping_patience}")
            if patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Count the actual number of epochs run
        actual_epochs = epoch + 1

    # Plot training and validation loss/accuracy with dynamic epoch count
    plt.figure(figsize=(12, 6))

    # Loss plot
    plt.subplot(1, 2, 1)
    plt.plot(range(1, actual_epochs + 1), train_losses[:actual_epochs], label='Train Loss')
    plt.plot(range(1, actual_epochs + 1), val_losses[:actual_epochs], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')

    # Accuracy plot
    plt.subplot(1, 2, 2)
    plt.plot(range(1, actual_epochs + 1), train_accuracies[:actual_epochs], label='Train Accuracy')
    plt.plot(range(1, actual_epochs + 1), val_accuracies[:actual_epochs], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.title('Training and Validation Accuracy')

    # Save the plot
    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.close()

    return model


In [8]:
def main():
    # Cấu hình dữ liệu
    data_root = '/kaggle/input/msasl-process/processdata'
    train_data_path = os.path.join(data_root, 'processed_data_MS_ASL100_Train')
    val_data_path = os.path.join(data_root, 'processed_data_MS_ASL100_Val')
    test_data_path = os.path.join(data_root, 'processed_data_MS_ASL100_Test')
    
    # Các tham số huấn luyện
    batch_size = 32
    num_frames = 16
    num_epochs = 50
    learning_rate = 0.0001
    weight_decay = 1e-5
    
    # Định nghĩa các phép biến đổi 
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomRotation(15),
        transforms.RandomCrop(224, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Tạo datasets
    print("Loading datasets...")
    train_dataset = VideoFrameDataset(train_data_path, transform=transform, num_frames=num_frames)
    val_dataset = VideoFrameDataset(val_data_path, transform=transform, num_frames=num_frames)
    test_dataset = VideoFrameDataset(test_data_path, transform=transform, num_frames=num_frames)
    
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    
    # Tạo dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    # Lấy danh sách các lớp
    classes = train_dataset.classes
    num_classes = len(classes)
    print(f"Number of classes: {num_classes}")
    
    # Khởi tạo mô hình
    print("Initializing Single-Frame CNN model...")
    model = SingleFrameCNN(num_classes=num_classes).to(device)
    
    # Định nghĩa loss function và optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1, verbose=True)
    
    # Huấn luyện mô hình
    print("Starting training...")
    model = train_model(
        model, train_loader, val_loader, criterion, optimizer, 
        num_epochs=num_epochs, scheduler=scheduler, early_stopping_patience=10
    )
    
    # Kiểm tra và tải mô hình tốt nhất
    model_path = 'best_single_frame_cnn_model.pth'
    if os.path.exists(model_path):
        print("Loading best model...")
        model.load_state_dict(torch.load(model_path))
    else:
        print(f"Model not found at {model_path}, skipping model loading.")
    
    # Đánh giá mô hình
    print("Evaluating model...")
    test_loss, test_acc = evaluate_model(model, test_loader, criterion, classes)
    
    # Lưu mô hình cuối cùng
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'classes': classes,
        'test_acc': test_acc,
    }, 'single_frame_cnn_final_model.pth')
    
    print(f"Final model saved with test accuracy: {test_acc:.2f}%")


In [9]:
if __name__ == "__main__":
    main()

Loading datasets...
Train samples: 3495
Val samples: 850
Test samples: 1335
Number of classes: 154
Initializing Single-Frame CNN model...


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 208MB/s]


Starting training...




Epoch 1/50, Train Loss: 5.1461, Train Accuracy: 1.29%
Epoch 1/50, Validation Loss: 4.8401, Validation Accuracy: 3.53%
  Saved best model with val accuracy: 3.53%
Epoch 2/50, Train Loss: 4.5500, Train Accuracy: 7.10%
Epoch 2/50, Validation Loss: 4.5171, Validation Accuracy: 9.41%
  Saved best model with val accuracy: 9.41%
Epoch 3/50, Train Loss: 4.0685, Train Accuracy: 15.57%
Epoch 3/50, Validation Loss: 4.3517, Validation Accuracy: 12.00%
  Saved best model with val accuracy: 12.00%
Epoch 4/50, Train Loss: 3.6655, Train Accuracy: 23.58%
Epoch 4/50, Validation Loss: 4.0719, Validation Accuracy: 16.24%
  Saved best model with val accuracy: 16.24%
Epoch 5/50, Train Loss: 3.2732, Train Accuracy: 32.19%
Epoch 5/50, Validation Loss: 3.8512, Validation Accuracy: 20.59%
  Saved best model with val accuracy: 20.59%
Epoch 6/50, Train Loss: 2.9172, Train Accuracy: 39.77%
Epoch 6/50, Validation Loss: 3.8239, Validation Accuracy: 19.18%
  Early stopping patience: 1/10
Epoch 7/50, Train Loss: 2.567

  model.load_state_dict(torch.load(model_path))


Test Loss: 1.4337, Test Accuracy: 67.72%

Classification Report:
              precision    recall  f1-score   support

      afraid       1.00      1.00      1.00         2
       again       0.09      0.40      0.15         5
         all       0.00      0.00      0.00         4
       apple       1.00      0.67      0.80         3
        aunt       0.00      0.00      0.00         2
         bad       0.00      0.00      0.00         2
    bathroom       0.00      0.00      0.00         1
   beautiful       0.00      0.00      0.00         6
     bicycle       0.60      0.60      0.60         5
       black       0.50      0.67      0.57         3
        blue       0.00      0.00      0.00         2
        book       1.00      0.25      0.40         8
         boy       0.00      0.00      0.00         4
   boyfriend       1.00      0.33      0.50         3
     brother       0.25      0.20      0.22         5
       brown       1.00      0.33      0.50         3
         bus    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
