In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchmetrics import functional as torchmetrics

# Number of action classes: sizes the classifier's output layer and is passed
# to the accuracy metric as `num_classes`.
TOT_ACTION_CLASSES = 4  # Update this based on your class count
# Fixed number of frames per sequence: shorter sequences are zero-padded and
# longer ones are truncated to this length when the datasets are loaded.
WINDOW_SIZE = 30

class PoseDataset(Dataset):
    """Labelled keypoint sequences read from one sub-folder per class.

    Every regular file in ``folder_path/<class_name>/`` is parsed as
    comma-separated text with one frame per line. Rows that do not contain
    exactly ``num_features`` numeric values are dropped, files with fewer than
    ``min_frames`` valid rows are skipped, and every remaining sequence is
    zero-padded or truncated to ``window_size`` frames.

    Args:
        folder_path: Directory containing one sub-directory per class.
        classes: Class folder names; the position in this list is the label.
        window_size: Frames per returned sequence. ``None`` (default) uses the
            module-level ``WINDOW_SIZE``.
        num_features: Number of values expected on every row (default 44).
        min_frames: Minimum number of valid rows for a file to be kept
            (default 9).
    """

    def __init__(self, folder_path, classes, window_size=None, num_features=44, min_frames=9):
        self.data = []
        self.labels = []
        self.folder_path = folder_path
        self.classes = classes
        # Resolve the global lazily so an explicit window_size never needs it.
        self.window_size = WINDOW_SIZE if window_size is None else window_size
        self.num_features = num_features
        self.min_frames = min_frames
        self.load_data()

    def load_keypoints(self, file_path):
        """Parse one keypoint file.

        Returns:
            A float32 array of shape ``(frames, num_features)``, or ``None``
            when the file contains no valid row.
        """
        keypoints = []
        with open(file_path, 'r') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    # Blank lines carry no data; skip them without reporting an error.
                    continue
                try:
                    keypoint_row = [float(x) for x in stripped.split(',')]
                except ValueError as e:
                    print(f"Error converting values in {file_path}: {e}")
                    continue
                if len(keypoint_row) == self.num_features:
                    keypoints.append(keypoint_row)
                else:
                    print(f"Skipping malformed row in {file_path}: {line}")

        if not keypoints:
            print(f"Skipping empty file: {file_path}")
            return None

        return np.array(keypoints, dtype=np.float32)

    def _fit_to_window(self, keypoints):
        """Zero-pad at the end, or truncate, so the sequence has ``window_size`` frames."""
        n_frames = keypoints.shape[0]
        if n_frames < self.window_size:
            pad_length = self.window_size - n_frames
            return np.pad(keypoints, ((0, pad_length), (0, 0)), 'constant', constant_values=0)
        return keypoints[:self.window_size]

    def load_data(self):
        """Populate ``self.data`` / ``self.labels`` from the class folders."""
        for class_idx, class_name in enumerate(self.classes):
            class_folder = os.path.join(self.folder_path, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample order identical across runs and platforms.
            for txt_file in sorted(os.listdir(class_folder)):
                file_path = os.path.join(class_folder, txt_file)
                if not os.path.isfile(file_path):
                    # Sub-directories cannot be opened as text files.
                    continue

                keypoints = self.load_keypoints(file_path)
                if keypoints is None:
                    continue

                if keypoints.ndim != 2 or keypoints.shape[1] != self.num_features:
                    print(f"Invalid keypoint shape in {file_path}: {keypoints.shape}")
                    continue

                if keypoints.shape[0] < self.min_frames:
                    print(f"Skipping file with fewer than {self.min_frames} frames: {file_path}")
                    continue

                self.data.append(self._fit_to_window(keypoints))
                self.labels.append(class_idx)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(sequence, label)`` as a float32 tensor and a long tensor."""
        return torch.tensor(self.data[idx], dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.long)


class PoseDataModule:
    """Bundles the training and validation datasets with their data loaders.

    ``data_root`` is expected to contain ``train`` and ``val`` sub-folders;
    ``setup`` must be called before either loader is requested.
    """

    def __init__(self, data_root, batch_size, classes):
        self.data_root, self.batch_size, self.classes = data_root, batch_size, classes

    def _dataset_for(self, split):
        """Build the PoseDataset stored under ``<data_root>/<split>``."""
        split_folder = os.path.join(self.data_root, split)
        return PoseDataset(split_folder, self.classes)

    def _loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def setup(self, stage=None):
        """Load both splits from disk; ``stage`` is accepted but not used."""
        self.train_dataset = self._dataset_for('train')
        self.val_dataset = self._dataset_for('val')

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation split, in stored order."""
        return self._loader(self.val_dataset, shuffle=False)








class ActionClassificationBiLSTM(nn.Module):
    """Stacked bidirectional LSTM with attention pooling for sequence classification.

    The LSTM output for every time step is scored by a linear attention layer,
    the scores are soft-maxed over the sequence axis, and the weighted sum of
    the LSTM outputs is passed through ``fc1 -> BatchNorm -> ReLU -> Dropout
    -> fc2`` to produce class logits.

    Args:
        input_features: Number of values per frame.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        learning_rate: Stored on the instance for reference; the optimizer is
            created by the caller.
        num_classes: Size of the output layer. ``None`` (default) uses the
            module-level ``TOT_ACTION_CLASSES``.
    """

    def __init__(self, input_features, hidden_dim, num_layers=3, learning_rate=0.001, num_classes=None):
        super().__init__()

        # Resolve the global lazily so an explicit num_classes never needs it.
        self.num_classes = TOT_ACTION_CLASSES if num_classes is None else num_classes
        self.learning_rate = learning_rate

        # BiLSTM layers
        self.lstm = nn.LSTM(
            input_features,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout only exists between stacked layers; passing a
            # non-zero value with a single layer just triggers a warning.
            dropout=0.3 if num_layers > 1 else 0.0,
        )

        # Fully connected layers with batch normalization and dropout
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.batch_norm_fc1 = nn.BatchNorm1d(hidden_dim)

        # Attention mechanism: one score per time step
        self.attention = nn.Linear(hidden_dim * 2, 1)

        # Final output layer
        self.fc2 = nn.Linear(hidden_dim, self.num_classes)

        # Dropout
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits ``[batch, num_classes]`` for ``x`` of shape ``[batch, seq_len, input_features]``."""
        # BiLSTM layers
        lstm_out, _ = self.lstm(x)  # Output: [batch, seq_len, hidden_dim*2]

        # Attention weights, normalised over the sequence axis
        attn_weights = torch.softmax(self.attention(lstm_out), dim=1)  # [batch, seq_len, 1]

        # Weighted sum of LSTM output with attention
        weighted_lstm_out = torch.sum(lstm_out * attn_weights, dim=1)  # [batch, hidden_dim*2]

        # First fully connected layer with batch norm and ReLU
        x = self.fc1(weighted_lstm_out)
        x = self.batch_norm_fc1(x)
        x = torch.relu(x)

        # Dropout
        x = self.dropout(x)

        # Output layer
        x = self.fc2(x)

        return x

    @staticmethod
    def _prepare_targets(y):
        """Flatten labels to a 1-D long tensor.

        ``reshape(-1)`` is used instead of ``squeeze`` because squeezing the
        labels of a single-sample batch yields a 0-dim tensor, which no longer
        matches the ``[1, num_classes]`` logits in ``cross_entropy``.
        """
        return y.reshape(-1).long()

    def _loss_and_accuracy(self, batch):
        """Run a forward pass on ``batch`` and return ``(loss, accuracy)``."""
        x, y = batch
        y = self._prepare_targets(y)

        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        acc = torchmetrics.accuracy(y_pred, y, task='multiclass', num_classes=self.num_classes)
        return loss, acc

    def train_step(self, batch, optimizer):
        """Run one optimisation step on ``batch``.

        Returns:
            Dict with the batch ``'loss'`` (detached from the graph) and ``'acc'``.
        """
        loss, acc = self._loss_and_accuracy(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach so callers holding the result do not keep autograd history alive.
        return {'loss': loss.detach(), 'acc': acc}

    def validation_step(self, batch, batch_idx):
        """Compute loss and accuracy on ``batch`` without updating weights.

        ``batch_idx`` is accepted but not used.
        """
        loss, acc = self._loss_and_accuracy(batch)

        return {'loss': loss, 'acc': acc}


def do_training_validation(data_root, batch_size, learning_rate, classes, epochs=150):
    """Train an ActionClassificationBiLSTM and validate it after every epoch.

    The model is built with 44 input features, a hidden size of 64 and 3 LSTM
    layers, optimised with Adam, and the learning rate is halved by
    ``ReduceLROnPlateau`` when the validation loss stops improving for 10
    epochs.

    Args:
        data_root: Folder containing ``train`` and ``val`` sub-folders.
        batch_size: Batch size for both loaders.
        learning_rate: Initial Adam learning rate.
        classes: Class folder names; list position is the label.
        epochs: Number of passes over the training set.

    Returns:
        The trained model.

    Raises:
        ValueError: If no usable training sequence was loaded.
    """
    model = ActionClassificationBiLSTM(input_features=44, hidden_dim=64, num_layers=3, learning_rate=learning_rate)
    data_module = PoseDataModule(data_root, batch_size, classes)
    data_module.setup()

    if len(data_module.train_dataset) == 0:
        raise ValueError(f"No usable training sequences found under {data_root}")

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # `verbose` is deprecated on recent PyTorch schedulers; learning-rate
    # changes are reported explicitly after each scheduler step instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, min_lr=1e-15)

    # The loaders are loop-invariant; a shuffled DataLoader draws a new
    # permutation every time it is iterated.
    train_loader = data_module.train_dataloader()
    val_loader = data_module.val_dataloader()

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1} started")
        model.train()

        train_loss, train_acc, train_samples = 0.0, 0.0, 0
        for batch in train_loader:
            n_samples = batch[0].size(0)
            if n_samples < 2:
                # BatchNorm1d cannot compute batch statistics from a single
                # sample in training mode, so a leftover batch of one is skipped.
                continue
            result = model.train_step(batch, optimizer)
            # Weight by batch size so a short final batch does not skew the mean.
            train_loss += result['loss'].item() * n_samples
            train_acc += result['acc'].item() * n_samples
            train_samples += n_samples

        if train_samples:
            print(f"Train loss: {train_loss / train_samples:.2f}, Train accuracy: {train_acc / train_samples:.2f}")

        model.eval()
        val_loss, val_acc, val_samples = 0.0, 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                n_samples = batch[0].size(0)
                result = model.validation_step(batch, epoch)
                val_loss += result['loss'].item() * n_samples
                val_acc += result['acc'].item() * n_samples
                val_samples += n_samples

        if val_samples == 0:
            # Nothing to average and nothing for the scheduler to monitor.
            print("Validation set is empty; skipping validation metrics and LR scheduling")
            continue

        val_loss /= val_samples
        val_acc /= val_samples
        print(f"Validation loss: {val_loss:.2f}, Validation accuracy: {val_acc:.2f}")

        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != previous_lr:
            print(f"Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}")

    return model


# Example usage
# NOTE(review): absolute, machine-specific path; the folder must contain
# `train` and `val` sub-folders, each with one sub-folder per entry in CLASSES.
DATASET_PATH = r"D:\gaurav\shopper_mediapipe_handpose\merl_classification\4_class_training"
# Class folder names; the position in this list is the integer label.
CLASSES = ['action_action_0', 'action_action_1', 'action_action_2', 'action_action_3']
# Trains for 80 epochs and keeps the returned model in the notebook namespace.
model = do_training_validation(DATASET_PATH, batch_size=30, learning_rate=0.001, classes=CLASSES, epochs=80)


Skipping empty file: D:\gaurav\shopper_mediapipe_handpose\merl_classification\4_class_training\train\action_action_0\10_1_action_00_000003.txt
Skipping empty file: D:\gaurav\shopper_mediapipe_handpose\merl_classification\4_class_training\train\action_action_0\11_3_action_00_000004.txt
Skipping empty file: D:\gaurav\shopper_mediapipe_handpose\merl_classification\4_class_training\train\action_action_0\11_3_action_00_000010.txt
Skipping file with fewer than 9 frames: D:\gaurav\shopper_mediapipe_handpose\merl_classification\4_class_training\train\action_action_0\11_3_action_00_000011.txt
Skipping file with fewer than 9 frames: D:\gaurav\shopper_mediapipe_handpose\merl_classification\4_class_training\train\action_action_0\12_1_action_00_000006.txt
Skipping empty file: D:\gaurav\shopper_mediapipe_handpose\merl_classification\4_class_training\train\action_action_0\12_2_action_00_000004.txt
Skipping empty file: D:\gaurav\shopper_mediapipe_handpose\merl_classification\4_class_training\train\act

In [4]:
# Put the trained model in evaluation mode; eval() returns the module itself,
# so its layer summary is displayed as the cell output.
model.eval()

ActionClassificationBiLSTM(
  (lstm): LSTM(44, 64, num_layers=3, batch_first=True, dropout=0.3, bidirectional=True)
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (batch_norm_fc1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (attention): Linear(in_features=128, out_features=1, bias=True)
  (fc2): Linear(in_features=64, out_features=4, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [14]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Constants (same values as in the training cell)
# Frames per sequence: inputs are zero-padded or truncated to this length.
WINDOW_SIZE = 30
# Class folder names used at training time; passed to PoseInferenceDataset by infer().
CLASSES = ['action_action_0', 'action_action_1', 'action_action_2', 'action_action_3']

class PoseInferenceDataset(Dataset):
    """Unlabelled keypoint sequences read from the files of a single folder.

    Every regular file in ``folder_path`` is parsed as comma-separated text
    with one frame per line. Rows that do not contain exactly ``num_features``
    numeric values are dropped, files with fewer than ``min_frames`` valid
    rows are skipped, and every remaining sequence is zero-padded or truncated
    to ``window_size`` frames.

    ``file_paths[i]`` is the source file of sample ``i``, so predictions can
    be mapped back to files even when some files were skipped.

    Args:
        folder_path: Directory containing the keypoint files.
        classes: Class names; stored on the instance but not used for loading.
        window_size: Frames per returned sequence. ``None`` (default) uses the
            module-level ``WINDOW_SIZE``.
        num_features: Number of values expected on every row (default 44).
        min_frames: Minimum number of valid rows for a file to be kept
            (default 9).
    """

    def __init__(self, folder_path, classes, window_size=None, num_features=44, min_frames=9):
        self.data = []
        self.file_paths = []
        self.folder_path = folder_path
        self.classes = classes
        # Resolve the global lazily so an explicit window_size never needs it.
        self.window_size = WINDOW_SIZE if window_size is None else window_size
        self.num_features = num_features
        self.min_frames = min_frames
        self.load_data()

    def load_keypoints(self, file_path):
        """Parse one keypoint file.

        Returns:
            A float32 array of shape ``(frames, num_features)``, or ``None``
            when the file contains no valid row.
        """
        keypoints = []
        with open(file_path, 'r') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    # Blank lines carry no data; skip them without reporting an error.
                    continue
                try:
                    # Split the line by commas and convert to float
                    keypoint_row = [float(x) for x in stripped.split(',')]
                except ValueError as e:
                    print(f"Error converting values in {file_path}: {e}")
                    continue
                if len(keypoint_row) == self.num_features:
                    keypoints.append(keypoint_row)
                else:
                    print(f"Skipping malformed row in {file_path}: {line}")

        # Skip if no keypoints are found
        if not keypoints:
            print(f"Skipping empty file: {file_path}")
            return None

        return np.array(keypoints, dtype=np.float32)

    def _fit_to_window(self, keypoints):
        """Zero-pad at the end, or truncate, so the sequence has ``window_size`` frames."""
        n_frames = keypoints.shape[0]
        if n_frames < self.window_size:
            pad_length = self.window_size - n_frames
            return np.pad(keypoints, ((0, pad_length), (0, 0)), 'constant', constant_values=0)
        return keypoints[:self.window_size]

    def load_data(self):
        """Populate ``self.data`` and ``self.file_paths`` from ``folder_path``."""
        # os.listdir returns entries in arbitrary order; sorting makes the
        # order of the samples (and therefore of the predictions) reproducible.
        for txt_file in sorted(os.listdir(self.folder_path)):
            file_path = os.path.join(self.folder_path, txt_file)
            if not os.path.isfile(file_path):
                # Sub-directories cannot be opened as text files.
                continue

            keypoints = self.load_keypoints(file_path)
            if keypoints is None:
                continue

            # Skip sequences that are too short to be meaningful
            if keypoints.shape[0] < self.min_frames:
                print(f"Skipping file with fewer than {self.min_frames} frames: {file_path}")
                continue

            self.data.append(self._fit_to_window(keypoints))
            self.file_paths.append(file_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a float32 tensor of shape ``(window_size, num_features)``."""
        return torch.tensor(self.data[idx], dtype=torch.float32)


def infer(model, test_data_path, batch_size=16):
    """Predict a class index for every usable keypoint file in a folder.

    Args:
        model: Trained classifier mapping ``[batch, frames, features]`` to logits.
        test_data_path: Folder of keypoint files, loaded with PoseInferenceDataset.
        batch_size: Number of sequences per forward pass (default 16).

    Returns:
        List of predicted class indices, one per sample in the dataset, in
        dataset order. Files skipped by the dataset produce no entry.

    Note:
        The model is left in evaluation mode.
    """
    # Set the model to evaluation mode
    model.eval()

    # Create an instance of the inference dataset
    test_dataset = PoseInferenceDataset(test_data_path, CLASSES)

    # Create a DataLoader for the test dataset
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Run on whichever device the model lives on, so a model that was moved
    # to an accelerator does not fail on CPU batches.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    predictions = []
    with torch.no_grad():
        for batch in test_dataloader:
            y_pred = model(batch.to(device))  # Forward pass
            preds = torch.argmax(y_pred, dim=1)  # Get predicted class
            predictions.extend(preds.cpu().numpy())  # Collect predictions

    return predictions


# Example usage
# NOTE(review): absolute, machine-specific path. `model` is not defined in this
# cell; it must already exist in the notebook namespace (created by the training cell).
TEST_DATA_PATH = r"D:\gaurav\shopper_mediapipe_handpose\merl_classification\merl_test_keypoint\test_merl_31toALl\action_action_4"  # Path to your test data
predictions = infer(model, TEST_DATA_PATH)

# Output predictions: one predicted class index per line.
# `idx` is only needed by the commented-out verbose print below.
for idx, pred in enumerate(predictions):
    #print(f"Test sample {idx}: Predicted class - {pred}")
    print(pred)


3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
0
3
0
3
3
3
3
3
3
3
0
3
3
3
3
3
3
3
3
3
3
1
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
0
3
3
3
3
3
3
1
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
2
3
3
1
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
3
3
3
3
3
3
3
3
3
3
1
3
3
1
3
3
3
3
3
3
3
3
3
3
3
0
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
1
3
1
3
3
3
3
3
3
3


In [6]:
number = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1
]



# Tally the entries of `number` equal to 0, then print that tally followed by
# the total number of entries (presumably pasted predictions; confirm source).
count = sum(1 for label in number if label == 0)
print(count)
print(len(number))

224
246


In [8]:
number = [
    1, 3, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 2, 1, 3, 1, 2, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 2, 1, 1, 1, 2, 1, 2, 2
]


# Tally the entries of `number` equal to 1, then print that tally followed by
# the total number of entries (presumably pasted predictions; confirm source).
count = sum(1 for label in number if label == 1)
print(count)
print(len(number))

198
233


In [11]:
number = [
    2, 2, 1, 0, 2, 2, 2, 0, 2, 2, 0, 2, 1, 2, 2, 2, 0, 2, 2, 1, 2, 2, 0, 2, 3, 2, 3, 2, 2, 1, 2, 3, 0, 1, 2, 2, 2, 2, 2, 0, 2, 1, 2, 2, 2, 0, 2, 3, 2, 1, 1, 2, 2, 2, 3, 0, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 3, 2, 0, 3, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 0, 2, 2, 3, 0, 2, 2, 2, 1, 2, 2, 2, 1, 2
]

# Tally the entries of `number` equal to 2, then print that tally followed by
# the total number of entries (presumably pasted predictions; confirm source).
count = sum(1 for label in number if label == 2)
print(count)
print(len(number))

61
96


In [13]:
number = [
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 1, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1, 1, 3, 3, 3, 3, 0, 3, 1, 3, 3, 3, 3, 3, 1, 3, 1, 3, 2, 3, 1, 0, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 0, 3, 3, 1, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3
]

# Tally the entries of `number` equal to 3, then print that tally followed by
# the total number of entries (presumably pasted predictions; confirm source).
count = sum(1 for label in number if label == 3)
print(count)
print(len(number))

143
168


In [15]:
number = [
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 0, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 3, 3, 3, 3, 3, 3
]

# Tally the entries of `number` equal to 3, then print that tally followed by
# the total number of entries (presumably pasted predictions; confirm source).
count = sum(1 for label in number if label == 3)
print(count)
print(len(number))

204
218


In [20]:
# Put the model in evaluation mode; eval() returns the module itself, so its
# layer summary is displayed as the cell output.
model.eval()

ActionClassificationBiLSTM(
  (lstm): LSTM(44, 64, num_layers=3, batch_first=True, dropout=0.3, bidirectional=True)
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (batch_norm_fc1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (attention): Linear(in_features=128, out_features=1, bias=True)
  (fc2): Linear(in_features=64, out_features=4, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [16]:
## Save the model

import torch

def save_model(model, path):
    """Save a PyTorch model's ``state_dict`` to ``path``.

    Only the parameters and buffers are written, not the model class, so
    loading requires constructing the same architecture first and calling
    ``load_state_dict``.

    Args:
        model: Module whose ``state_dict()`` is saved.
        path: Destination file path (or a file-like object accepted by
            ``torch.save``). Missing parent directories of a file path are
            created.
    """
    import os

    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(path))
        # An empty dirname means the current directory, which already exists.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)


# NOTE(review): absolute, machine-specific destination path.
# `model` is not defined in this cell; it must already exist in the notebook namespace.
save_path = r"D:\gaurav\shopper_mediapipe_handpose\merl_classification\BiLSTM_64neuron_4class_best_till.pt"  # Replace with your desired path
save_model(model, save_path)
print(f"Model saved to {save_path}")

Model saved to D:\gaurav\shopper_mediapipe_handpose\merl_classification\BiLSTM_64neuron_4class_best_till.pt
