In [5]:
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd
from social_lstm_classifier import SocialLSTMClassifier

trajectory_features = ['Positionx', 'Positiony', 'Distance', 'Speed', 'Speed Change', 'Direction', 'Direction Change']

feature_combos = {
    # 'xy_2': ['Positionx', 'Positiony'],
    # 'xyd_3': ['Positionx', 'Positiony', 'Distance'],
    'xydsdc_7': ['Positionx', 'Positiony', 'Distance', 'Speed', 'Direction', 'Speed Change', 'Direction Change'],
    # 'xysd_4': ['Positionx', 'Positiony', 'Speed', 'Direction'],
    # 'xyds_4': ['Positionx', 'Positiony', 'Distance', 'Speed'],
    # 'xydd_4': ['Positionx', 'Positiony', 'Distance', 'Direction'],
    # 'xydsc_5': ['Positionx', 'Positiony', 'Distance', 'Speed', 'Speed Change'],
    # 'xyddc_5': ['Positionx', 'Positiony', 'Distance', 'Direction', 'Direction Change'],
    # 'xys_3': ['Positionx', 'Positiony', 'Speed'],
    'xydi_3': ['Positionx', 'Positiony', 'Direction'],
    # 'xysc_4': ['Positionx', 'Positiony', 'Speed', 'Speed Change'],
    'xydc_4': ['Positionx', 'Positiony', 'Direction', 'Direction Change']
}

train_data = torch.load("/Users/anzhunie/Desktop/Pedestrian_Training/Prediction/train_social_lstm_full.pt")
test_data = torch.load("/Users/anzhunie/Desktop/Pedestrian_Training/Prediction/test_social_lstm_full.pt")

results = []

def run_experiment(feature_set, name):
    input_size = len(feature_set)
    index_map = [trajectory_features.index(f) for f in feature_set]

    model = SocialLSTMClassifier(input_size=input_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    epochs = 50

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for sample in tqdm(train_data, desc=f"Epoch {epoch+1}/{epochs}"):
            traj = sample['trajectory'][:, index_map].unsqueeze(1)
            neighbors = sample['neighbors'][:, :, index_map]
            mask = sample['neighbor_mask']
            label = torch.tensor([int(sample['cluster']) - 1], dtype=torch.long)

            optimizer.zero_grad()
            logits = model(traj, neighbors, mask)
            loss = criterion(logits, label)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_data)
        print(f"\n[Feature Set: {name}] Epoch [{epoch+1}/{epochs}] - Avg Loss: {avg_loss:.4f}")


    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
       for sample in tqdm(train_data, desc=f"Epoch {epoch+1}/{epochs}"):
            traj = sample['trajectory'][:, index_map].unsqueeze(1)
            neighbors = sample['neighbors'][:, :, index_map]
            mask = sample['neighbor_mask']
            label = int(sample['cluster']) - 1
            logits = model(traj, neighbors, mask)
            pred = logits.argmax(dim=1).item()
            y_true.append(label)
            y_pred.append(pred)

    acc = accuracy_score(y_true, y_pred)
    clf_report = classification_report(y_true, y_pred, output_dict=True)
    results.append({
        "Feature_Set": name,
        "Accuracy": acc,
        "Macro_F1": clf_report['macro avg']['f1-score'],
        "Weighted_F1": clf_report['weighted avg']['f1-score']
    })

for name, feats in feature_combos.items():
    run_experiment(feats, name)



  train_data = torch.load("/Users/anzhunie/Desktop/Pedestrian_Training/Prediction/train_social_lstm_full.pt")
  test_data = torch.load("/Users/anzhunie/Desktop/Pedestrian_Training/Prediction/test_social_lstm_full.pt")
Epoch 1/50: 100%|██████████| 4920/4920 [00:32<00:00, 150.89it/s]



[Feature Set: xydsdc_7] Epoch [1/50] - Avg Loss: 0.6447


Epoch 2/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.87it/s]



[Feature Set: xydsdc_7] Epoch [2/50] - Avg Loss: 0.6139


Epoch 3/50: 100%|██████████| 4920/4920 [00:31<00:00, 154.56it/s]



[Feature Set: xydsdc_7] Epoch [3/50] - Avg Loss: 0.5896


Epoch 4/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.02it/s]



[Feature Set: xydsdc_7] Epoch [4/50] - Avg Loss: 0.5626


Epoch 5/50: 100%|██████████| 4920/4920 [00:31<00:00, 154.05it/s]



[Feature Set: xydsdc_7] Epoch [5/50] - Avg Loss: 0.5326


Epoch 6/50: 100%|██████████| 4920/4920 [00:32<00:00, 152.26it/s]



[Feature Set: xydsdc_7] Epoch [6/50] - Avg Loss: 0.5067


Epoch 7/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.27it/s]



[Feature Set: xydsdc_7] Epoch [7/50] - Avg Loss: 0.4680


Epoch 8/50: 100%|██████████| 4920/4920 [00:32<00:00, 151.60it/s]



[Feature Set: xydsdc_7] Epoch [8/50] - Avg Loss: 0.4246


Epoch 9/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.31it/s]



[Feature Set: xydsdc_7] Epoch [9/50] - Avg Loss: 0.3946


Epoch 10/50: 100%|██████████| 4920/4920 [00:30<00:00, 158.81it/s]



[Feature Set: xydsdc_7] Epoch [10/50] - Avg Loss: 0.3704


Epoch 11/50: 100%|██████████| 4920/4920 [00:32<00:00, 153.29it/s]



[Feature Set: xydsdc_7] Epoch [11/50] - Avg Loss: 0.3265


Epoch 12/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.65it/s]



[Feature Set: xydsdc_7] Epoch [12/50] - Avg Loss: 0.3172


Epoch 13/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.08it/s]



[Feature Set: xydsdc_7] Epoch [13/50] - Avg Loss: 0.2865


Epoch 14/50: 100%|██████████| 4920/4920 [00:31<00:00, 156.36it/s]



[Feature Set: xydsdc_7] Epoch [14/50] - Avg Loss: 0.2702


Epoch 15/50: 100%|██████████| 4920/4920 [00:31<00:00, 154.14it/s]



[Feature Set: xydsdc_7] Epoch [15/50] - Avg Loss: 0.2541


Epoch 16/50: 100%|██████████| 4920/4920 [00:32<00:00, 153.75it/s]



[Feature Set: xydsdc_7] Epoch [16/50] - Avg Loss: 0.2431


Epoch 17/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.02it/s]



[Feature Set: xydsdc_7] Epoch [17/50] - Avg Loss: 0.2155


Epoch 18/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.80it/s]



[Feature Set: xydsdc_7] Epoch [18/50] - Avg Loss: 0.2142


Epoch 19/50: 100%|██████████| 4920/4920 [00:32<00:00, 153.07it/s]



[Feature Set: xydsdc_7] Epoch [19/50] - Avg Loss: 0.2147


Epoch 20/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.64it/s]



[Feature Set: xydsdc_7] Epoch [20/50] - Avg Loss: 0.1960


Epoch 21/50: 100%|██████████| 4920/4920 [00:32<00:00, 153.29it/s]



[Feature Set: xydsdc_7] Epoch [21/50] - Avg Loss: 0.2051


Epoch 22/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.35it/s]



[Feature Set: xydsdc_7] Epoch [22/50] - Avg Loss: 0.1823


Epoch 23/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.73it/s]



[Feature Set: xydsdc_7] Epoch [23/50] - Avg Loss: 0.1744


Epoch 24/50: 100%|██████████| 4920/4920 [00:31<00:00, 156.40it/s]



[Feature Set: xydsdc_7] Epoch [24/50] - Avg Loss: 0.1719


Epoch 25/50: 100%|██████████| 4920/4920 [00:31<00:00, 156.39it/s]



[Feature Set: xydsdc_7] Epoch [25/50] - Avg Loss: 0.1559


Epoch 26/50: 100%|██████████| 4920/4920 [00:31<00:00, 153.94it/s]



[Feature Set: xydsdc_7] Epoch [26/50] - Avg Loss: 0.1551


Epoch 27/50: 100%|██████████| 4920/4920 [00:32<00:00, 152.46it/s]



[Feature Set: xydsdc_7] Epoch [27/50] - Avg Loss: 0.1551


Epoch 28/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.68it/s]



[Feature Set: xydsdc_7] Epoch [28/50] - Avg Loss: 0.1530


Epoch 29/50: 100%|██████████| 4920/4920 [00:32<00:00, 149.54it/s]



[Feature Set: xydsdc_7] Epoch [29/50] - Avg Loss: 0.1519


Epoch 30/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.62it/s]



[Feature Set: xydsdc_7] Epoch [30/50] - Avg Loss: 0.1405


Epoch 31/50: 100%|██████████| 4920/4920 [00:33<00:00, 147.33it/s]



[Feature Set: xydsdc_7] Epoch [31/50] - Avg Loss: 0.1312


Epoch 32/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.85it/s]



[Feature Set: xydsdc_7] Epoch [32/50] - Avg Loss: 0.1437


Epoch 33/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.29it/s]



[Feature Set: xydsdc_7] Epoch [33/50] - Avg Loss: 0.1342


Epoch 34/50: 100%|██████████| 4920/4920 [00:30<00:00, 162.53it/s]



[Feature Set: xydsdc_7] Epoch [34/50] - Avg Loss: 0.1304


Epoch 35/50: 100%|██████████| 4920/4920 [00:30<00:00, 162.36it/s]



[Feature Set: xydsdc_7] Epoch [35/50] - Avg Loss: 0.1161


Epoch 36/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.59it/s]



[Feature Set: xydsdc_7] Epoch [36/50] - Avg Loss: 0.1190


Epoch 37/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.92it/s]



[Feature Set: xydsdc_7] Epoch [37/50] - Avg Loss: 0.1167


Epoch 38/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.62it/s]



[Feature Set: xydsdc_7] Epoch [38/50] - Avg Loss: 0.1049


Epoch 39/50: 100%|██████████| 4920/4920 [00:32<00:00, 150.08it/s]



[Feature Set: xydsdc_7] Epoch [39/50] - Avg Loss: 0.1094


Epoch 40/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.78it/s]



[Feature Set: xydsdc_7] Epoch [40/50] - Avg Loss: 0.1074


Epoch 41/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.64it/s]



[Feature Set: xydsdc_7] Epoch [41/50] - Avg Loss: 0.1206


Epoch 42/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.39it/s]



[Feature Set: xydsdc_7] Epoch [42/50] - Avg Loss: 0.1040


Epoch 43/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.56it/s]



[Feature Set: xydsdc_7] Epoch [43/50] - Avg Loss: 0.0991


Epoch 44/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.02it/s]



[Feature Set: xydsdc_7] Epoch [44/50] - Avg Loss: 0.1112


Epoch 45/50: 100%|██████████| 4920/4920 [00:32<00:00, 153.62it/s]



[Feature Set: xydsdc_7] Epoch [45/50] - Avg Loss: 0.0935


Epoch 46/50: 100%|██████████| 4920/4920 [00:31<00:00, 154.55it/s]



[Feature Set: xydsdc_7] Epoch [46/50] - Avg Loss: 0.1243


Epoch 47/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.62it/s]



[Feature Set: xydsdc_7] Epoch [47/50] - Avg Loss: 0.0944


Epoch 48/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.24it/s]



[Feature Set: xydsdc_7] Epoch [48/50] - Avg Loss: 0.1027


Epoch 49/50: 100%|██████████| 4920/4920 [00:32<00:00, 151.56it/s]



[Feature Set: xydsdc_7] Epoch [49/50] - Avg Loss: 0.0923


Epoch 50/50: 100%|██████████| 4920/4920 [00:33<00:00, 146.30it/s]



[Feature Set: xydsdc_7] Epoch [50/50] - Avg Loss: 0.1104


Epoch 50/50: 100%|██████████| 4920/4920 [00:14<00:00, 350.40it/s]
Epoch 1/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.68it/s]



[Feature Set: xydi_3] Epoch [1/50] - Avg Loss: 0.6431


Epoch 2/50: 100%|██████████| 4920/4920 [00:31<00:00, 154.97it/s]



[Feature Set: xydi_3] Epoch [2/50] - Avg Loss: 0.6039


Epoch 3/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.94it/s]



[Feature Set: xydi_3] Epoch [3/50] - Avg Loss: 0.5805


Epoch 4/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.22it/s]



[Feature Set: xydi_3] Epoch [4/50] - Avg Loss: 0.5654


Epoch 5/50: 100%|██████████| 4920/4920 [00:30<00:00, 158.91it/s]



[Feature Set: xydi_3] Epoch [5/50] - Avg Loss: 0.5385


Epoch 6/50: 100%|██████████| 4920/4920 [00:30<00:00, 158.72it/s]



[Feature Set: xydi_3] Epoch [6/50] - Avg Loss: 0.5166


Epoch 7/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.17it/s]



[Feature Set: xydi_3] Epoch [7/50] - Avg Loss: 0.4884


Epoch 8/50: 100%|██████████| 4920/4920 [00:30<00:00, 158.91it/s]



[Feature Set: xydi_3] Epoch [8/50] - Avg Loss: 0.4607


Epoch 9/50: 100%|██████████| 4920/4920 [00:31<00:00, 156.90it/s]



[Feature Set: xydi_3] Epoch [9/50] - Avg Loss: 0.4424


Epoch 10/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.45it/s]



[Feature Set: xydi_3] Epoch [10/50] - Avg Loss: 0.4139


Epoch 11/50: 100%|██████████| 4920/4920 [00:35<00:00, 137.42it/s]



[Feature Set: xydi_3] Epoch [11/50] - Avg Loss: 0.3949


Epoch 12/50: 100%|██████████| 4920/4920 [00:34<00:00, 143.94it/s]



[Feature Set: xydi_3] Epoch [12/50] - Avg Loss: 0.3700


Epoch 13/50: 100%|██████████| 4920/4920 [00:34<00:00, 142.02it/s]



[Feature Set: xydi_3] Epoch [13/50] - Avg Loss: 0.3627


Epoch 14/50: 100%|██████████| 4920/4920 [00:31<00:00, 156.36it/s]



[Feature Set: xydi_3] Epoch [14/50] - Avg Loss: 0.3482


Epoch 15/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.55it/s]



[Feature Set: xydi_3] Epoch [15/50] - Avg Loss: 0.3251


Epoch 16/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.76it/s]



[Feature Set: xydi_3] Epoch [16/50] - Avg Loss: 0.3229


Epoch 17/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.37it/s]



[Feature Set: xydi_3] Epoch [17/50] - Avg Loss: 0.2975


Epoch 18/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.59it/s]



[Feature Set: xydi_3] Epoch [18/50] - Avg Loss: 0.2838


Epoch 19/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.10it/s]



[Feature Set: xydi_3] Epoch [19/50] - Avg Loss: 0.2788


Epoch 20/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.79it/s]



[Feature Set: xydi_3] Epoch [20/50] - Avg Loss: 0.2546


Epoch 21/50: 100%|██████████| 4920/4920 [00:30<00:00, 162.29it/s]



[Feature Set: xydi_3] Epoch [21/50] - Avg Loss: 0.2688


Epoch 22/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.39it/s]



[Feature Set: xydi_3] Epoch [22/50] - Avg Loss: 0.2477


Epoch 23/50: 100%|██████████| 4920/4920 [00:30<00:00, 162.17it/s]



[Feature Set: xydi_3] Epoch [23/50] - Avg Loss: 0.2288


Epoch 24/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.08it/s]



[Feature Set: xydi_3] Epoch [24/50] - Avg Loss: 0.2287


Epoch 25/50: 100%|██████████| 4920/4920 [00:30<00:00, 162.81it/s]



[Feature Set: xydi_3] Epoch [25/50] - Avg Loss: 0.2263


Epoch 26/50: 100%|██████████| 4920/4920 [00:30<00:00, 158.78it/s]



[Feature Set: xydi_3] Epoch [26/50] - Avg Loss: 0.2115


Epoch 27/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.64it/s]



[Feature Set: xydi_3] Epoch [27/50] - Avg Loss: 0.2262


Epoch 28/50: 100%|██████████| 4920/4920 [00:33<00:00, 147.80it/s]



[Feature Set: xydi_3] Epoch [28/50] - Avg Loss: 0.1987


Epoch 29/50: 100%|██████████| 4920/4920 [00:32<00:00, 149.88it/s]



[Feature Set: xydi_3] Epoch [29/50] - Avg Loss: 0.1975


Epoch 30/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.71it/s]



[Feature Set: xydi_3] Epoch [30/50] - Avg Loss: 0.2045


Epoch 31/50: 100%|██████████| 4920/4920 [00:31<00:00, 154.40it/s]



[Feature Set: xydi_3] Epoch [31/50] - Avg Loss: 0.1869


Epoch 32/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.49it/s]



[Feature Set: xydi_3] Epoch [32/50] - Avg Loss: 0.1865


Epoch 33/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.53it/s]



[Feature Set: xydi_3] Epoch [33/50] - Avg Loss: 0.2033


Epoch 34/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.79it/s]



[Feature Set: xydi_3] Epoch [34/50] - Avg Loss: 0.1588


Epoch 35/50: 100%|██████████| 4920/4920 [00:31<00:00, 156.63it/s]



[Feature Set: xydi_3] Epoch [35/50] - Avg Loss: 0.1774


Epoch 36/50: 100%|██████████| 4920/4920 [00:32<00:00, 151.84it/s]



[Feature Set: xydi_3] Epoch [36/50] - Avg Loss: 0.1934


Epoch 37/50: 100%|██████████| 4920/4920 [00:31<00:00, 153.85it/s]



[Feature Set: xydi_3] Epoch [37/50] - Avg Loss: 0.1599


Epoch 38/50: 100%|██████████| 4920/4920 [00:33<00:00, 148.48it/s]



[Feature Set: xydi_3] Epoch [38/50] - Avg Loss: 0.1616


Epoch 39/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.56it/s]



[Feature Set: xydi_3] Epoch [39/50] - Avg Loss: 0.1497


Epoch 40/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.08it/s]



[Feature Set: xydi_3] Epoch [40/50] - Avg Loss: 0.1661


Epoch 41/50: 100%|██████████| 4920/4920 [00:32<00:00, 153.19it/s]



[Feature Set: xydi_3] Epoch [41/50] - Avg Loss: 0.1488


Epoch 42/50: 100%|██████████| 4920/4920 [00:32<00:00, 152.39it/s]



[Feature Set: xydi_3] Epoch [42/50] - Avg Loss: 0.1480


Epoch 43/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.55it/s]



[Feature Set: xydi_3] Epoch [43/50] - Avg Loss: 0.1595


Epoch 44/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.55it/s]



[Feature Set: xydi_3] Epoch [44/50] - Avg Loss: 0.1498


Epoch 45/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.58it/s]



[Feature Set: xydi_3] Epoch [45/50] - Avg Loss: 0.1452


Epoch 46/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.02it/s]



[Feature Set: xydi_3] Epoch [46/50] - Avg Loss: 0.1510


Epoch 47/50: 100%|██████████| 4920/4920 [00:31<00:00, 154.83it/s]



[Feature Set: xydi_3] Epoch [47/50] - Avg Loss: 0.1361


Epoch 48/50: 100%|██████████| 4920/4920 [00:32<00:00, 153.40it/s]



[Feature Set: xydi_3] Epoch [48/50] - Avg Loss: 0.1321


Epoch 49/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.38it/s]



[Feature Set: xydi_3] Epoch [49/50] - Avg Loss: 0.1656


Epoch 50/50: 100%|██████████| 4920/4920 [00:31<00:00, 156.48it/s]



[Feature Set: xydi_3] Epoch [50/50] - Avg Loss: 0.1274


Epoch 50/50: 100%|██████████| 4920/4920 [00:14<00:00, 343.98it/s]
Epoch 1/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.67it/s]



[Feature Set: xydc_4] Epoch [1/50] - Avg Loss: 0.6427


Epoch 2/50: 100%|██████████| 4920/4920 [00:32<00:00, 151.92it/s]



[Feature Set: xydc_4] Epoch [2/50] - Avg Loss: 0.6125


Epoch 3/50: 100%|██████████| 4920/4920 [00:32<00:00, 153.71it/s]



[Feature Set: xydc_4] Epoch [3/50] - Avg Loss: 0.5946


Epoch 4/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.98it/s]



[Feature Set: xydc_4] Epoch [4/50] - Avg Loss: 0.5713


Epoch 5/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.53it/s]



[Feature Set: xydc_4] Epoch [5/50] - Avg Loss: 0.5475


Epoch 6/50: 100%|██████████| 4920/4920 [00:33<00:00, 145.63it/s]



[Feature Set: xydc_4] Epoch [6/50] - Avg Loss: 0.5139


Epoch 7/50: 100%|██████████| 4920/4920 [00:32<00:00, 153.14it/s]



[Feature Set: xydc_4] Epoch [7/50] - Avg Loss: 0.4825


Epoch 8/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.91it/s]



[Feature Set: xydc_4] Epoch [8/50] - Avg Loss: 0.4538


Epoch 9/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.39it/s]



[Feature Set: xydc_4] Epoch [9/50] - Avg Loss: 0.4237


Epoch 10/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.59it/s]



[Feature Set: xydc_4] Epoch [10/50] - Avg Loss: 0.3855


Epoch 11/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.37it/s]



[Feature Set: xydc_4] Epoch [11/50] - Avg Loss: 0.3624


Epoch 12/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.64it/s]



[Feature Set: xydc_4] Epoch [12/50] - Avg Loss: 0.3353


Epoch 13/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.83it/s]



[Feature Set: xydc_4] Epoch [13/50] - Avg Loss: 0.3095


Epoch 14/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.48it/s]



[Feature Set: xydc_4] Epoch [14/50] - Avg Loss: 0.2987


Epoch 15/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.92it/s]



[Feature Set: xydc_4] Epoch [15/50] - Avg Loss: 0.2763


Epoch 16/50: 100%|██████████| 4920/4920 [00:31<00:00, 154.22it/s]



[Feature Set: xydc_4] Epoch [16/50] - Avg Loss: 0.2544


Epoch 17/50: 100%|██████████| 4920/4920 [00:30<00:00, 158.74it/s]



[Feature Set: xydc_4] Epoch [17/50] - Avg Loss: 0.2468


Epoch 18/50: 100%|██████████| 4920/4920 [00:30<00:00, 158.81it/s]



[Feature Set: xydc_4] Epoch [18/50] - Avg Loss: 0.2284


Epoch 19/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.42it/s]



[Feature Set: xydc_4] Epoch [19/50] - Avg Loss: 0.2314


Epoch 20/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.12it/s]



[Feature Set: xydc_4] Epoch [20/50] - Avg Loss: 0.2266


Epoch 21/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.54it/s]



[Feature Set: xydc_4] Epoch [21/50] - Avg Loss: 0.1933


Epoch 22/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.62it/s]



[Feature Set: xydc_4] Epoch [22/50] - Avg Loss: 0.1885


Epoch 23/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.10it/s]



[Feature Set: xydc_4] Epoch [23/50] - Avg Loss: 0.1968


Epoch 24/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.10it/s]



[Feature Set: xydc_4] Epoch [24/50] - Avg Loss: 0.1850


Epoch 25/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.64it/s]



[Feature Set: xydc_4] Epoch [25/50] - Avg Loss: 0.1802


Epoch 26/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.09it/s]



[Feature Set: xydc_4] Epoch [26/50] - Avg Loss: 0.1528


Epoch 27/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.62it/s]



[Feature Set: xydc_4] Epoch [27/50] - Avg Loss: 0.1747


Epoch 28/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.75it/s]



[Feature Set: xydc_4] Epoch [28/50] - Avg Loss: 0.1591


Epoch 29/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.23it/s]



[Feature Set: xydc_4] Epoch [29/50] - Avg Loss: 0.1766


Epoch 30/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.31it/s]



[Feature Set: xydc_4] Epoch [30/50] - Avg Loss: 0.1488


Epoch 31/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.27it/s]



[Feature Set: xydc_4] Epoch [31/50] - Avg Loss: 0.1567


Epoch 32/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.67it/s]



[Feature Set: xydc_4] Epoch [32/50] - Avg Loss: 0.1549


Epoch 33/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.48it/s]



[Feature Set: xydc_4] Epoch [33/50] - Avg Loss: 0.1438


Epoch 34/50: 100%|██████████| 4920/4920 [00:30<00:00, 160.62it/s]



[Feature Set: xydc_4] Epoch [34/50] - Avg Loss: 0.1221


Epoch 35/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.03it/s]



[Feature Set: xydc_4] Epoch [35/50] - Avg Loss: 0.1391


Epoch 36/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.23it/s]



[Feature Set: xydc_4] Epoch [36/50] - Avg Loss: 0.1161


Epoch 37/50: 100%|██████████| 4920/4920 [00:30<00:00, 158.99it/s]



[Feature Set: xydc_4] Epoch [37/50] - Avg Loss: 0.1257


Epoch 38/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.91it/s]



[Feature Set: xydc_4] Epoch [38/50] - Avg Loss: 0.1372


Epoch 39/50: 100%|██████████| 4920/4920 [00:31<00:00, 155.33it/s]



[Feature Set: xydc_4] Epoch [39/50] - Avg Loss: 0.1171


Epoch 40/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.14it/s]



[Feature Set: xydc_4] Epoch [40/50] - Avg Loss: 0.1299


Epoch 41/50: 100%|██████████| 4920/4920 [00:30<00:00, 158.93it/s]



[Feature Set: xydc_4] Epoch [41/50] - Avg Loss: 0.1246


Epoch 42/50: 100%|██████████| 4920/4920 [00:31<00:00, 154.19it/s]



[Feature Set: xydc_4] Epoch [42/50] - Avg Loss: 0.1153


Epoch 43/50: 100%|██████████| 4920/4920 [00:31<00:00, 158.55it/s]



[Feature Set: xydc_4] Epoch [43/50] - Avg Loss: 0.1107


Epoch 44/50: 100%|██████████| 4920/4920 [00:30<00:00, 159.05it/s]



[Feature Set: xydc_4] Epoch [44/50] - Avg Loss: 0.1201


Epoch 45/50: 100%|██████████| 4920/4920 [00:31<00:00, 157.19it/s]



[Feature Set: xydc_4] Epoch [45/50] - Avg Loss: 0.1171


Epoch 46/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.44it/s]



[Feature Set: xydc_4] Epoch [46/50] - Avg Loss: 0.0999


Epoch 47/50: 100%|██████████| 4920/4920 [00:30<00:00, 162.83it/s]



[Feature Set: xydc_4] Epoch [47/50] - Avg Loss: 0.1140


Epoch 48/50: 100%|██████████| 4920/4920 [00:30<00:00, 162.32it/s]



[Feature Set: xydc_4] Epoch [48/50] - Avg Loss: 0.1129


Epoch 49/50: 100%|██████████| 4920/4920 [00:30<00:00, 161.69it/s]



[Feature Set: xydc_4] Epoch [49/50] - Avg Loss: 0.1074


Epoch 50/50: 100%|██████████| 4920/4920 [00:30<00:00, 162.99it/s]



[Feature Set: xydc_4] Epoch [50/50] - Avg Loss: 0.1075


Epoch 50/50: 100%|██████████| 4920/4920 [00:14<00:00, 347.07it/s]


In [6]:
df_results = pd.DataFrame(results)
print(df_results)
# df_results.to_csv("/Users/anzhunie/Desktop/Pedestrian_Training/Prediction/feature_combination_results.csv", index=False)


  Feature_Set  Accuracy  Macro_F1  Weighted_F1
0    xydsdc_7  0.977642  0.976188     0.977628
1      xydi_3  0.951423  0.947710     0.951120
2      xydc_4  0.954065  0.950906     0.953954
