In [36]:
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Flatten, Dense, Concatenate, Multiply, Dropout, BatchNormalization, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import plot_model

In [37]:
NUM_INJURY_CATEGORIES = 28

injury_counts_input = Input(shape=(NUM_INJURY_CATEGORIES,), name='injury_counts')
position_input = Input(shape=(1,), name='position')
numerical_input = Input(shape=(6,), name='numerical') # height, weight, age, forty, bench, vertical

In [38]:
x_position = Embedding(input_dim=29, output_dim=6, input_length=1)(position_input) # train encoding on position index
x_position = Flatten()(x_position)

x_injury = Dense(32, activation='relu')(injury_counts_input)
x_injury = Dense(32, activation='relu')(x_injury)
x_injury = Dense(32, activation='relu')(x_injury)

x_num = Dense(32, activation='relu')(numerical_input)
x_num = Dense(32, activation='relu')(x_num)
x_num = Dense(32, activation='relu')(x_num)

x = Concatenate()([x_position, x_num])
x = Dense(64, activation='relu')(x)

x = Concatenate()([x, x_injury])

x = Dense(512, activation='relu')(x) # hidden layers for prediction
x = Dropout(0.3)(x)
x = Dense(256)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

output = Dense(NUM_INJURY_CATEGORIES, activation='softmax')(x)

model = Model(inputs=[injury_counts_input, position_input, numerical_input], outputs=output)
model.summary()

plot_model(model, to_file='model_architecture.png', show_shapes=True, show_layer_names=True)



You must install pydot (`pip install pydot`) for `plot_model` to work.


In [39]:
## Training
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.utils import to_categorical
from sklearn.utils import class_weight
import numpy as np
from imblearn.over_sampling import RandomOverSampler
from tensorflow.keras.utils import to_categorical
import pickle

df = pd.read_csv('data.csv')
injury_counts_cols = [col for col in df.columns if col.startswith('prev_')]
numerical_cols = ['height', 'weight', 'age', 'forty', 'bench', 'vertical']

scaler = StandardScaler()
df[numerical_cols] = scaler.fit_transform(df[numerical_cols]) # Normalize numerical data

with open('scaler.pkl', 'wb') as f:
    pickle.dump(scaler, f)

injury_counts_data = df[injury_counts_cols].values
position_data = df['position_index']
numerical_data = df[numerical_cols].values

y = to_categorical(df["injury_index"], num_classes=NUM_INJURY_CATEGORIES)

(X_injury_counts_train, X_injury_counts_val_test, X_position_train, X_position_val_test, X_numerical_train, X_numerical_val_test, y_train, y_val_test) = train_test_split(injury_counts_data, position_data, numerical_data, y, test_size=0.3)

(X_injury_counts_val, X_injury_counts_test, X_position_val, X_position_test, X_numerical_val, X_numerical_test, y_val, y_test) = train_test_split(X_injury_counts_val_test, X_position_val_test, X_numerical_val_test, y_val_test, test_size=0.15)


optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)

model.compile(optimizer=optimizer, 
              loss='categorical_crossentropy', 
              metrics=['accuracy'])

ros = RandomOverSampler()

# Some classes are underrepresented. Must do this to reduce bias
X_position_train = np.array(X_position_train).reshape(-1, 1)
X_train_combined = np.hstack([X_injury_counts_train, X_position_train, X_numerical_train])
y_train_indices = np.argmax(y_train, axis=1)
X_resampled, y_resampled = ros.fit_resample(X_train_combined, y_train_indices)
X_injury_counts_train = X_resampled[:, :NUM_INJURY_CATEGORIES]
X_position_train = X_resampled[:, NUM_INJURY_CATEGORIES:NUM_INJURY_CATEGORIES+1]
X_numerical_train = X_resampled[:, NUM_INJURY_CATEGORIES+1:]

y_train = to_categorical(y_resampled, num_classes=NUM_INJURY_CATEGORIES)

In [40]:
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

reduce_lr = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.5,
    patience=10,
    min_lr=1e-6,
    verbose=1
)

early_stopping = EarlyStopping(
    monitor='val_accuracy', 
    patience=20, 
    restore_best_weights=True
)

model.fit(
    x=[X_injury_counts_train, X_position_train, X_numerical_train],
    y=y_train,
    validation_data=([X_injury_counts_val, X_position_val, X_numerical_val], y_val),
    epochs=50,
    batch_size=128,
    callbacks=[reduce_lr, early_stopping]
)

model.save('tf_model.keras')

Epoch 1/50
[1m264/264[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.2845 - loss: 2.7070 - val_accuracy: 0.2724 - val_loss: 2.9811 - learning_rate: 5.0000e-04
Epoch 2/50
[1m264/264[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.6102 - loss: 1.4751 - val_accuracy: 0.3174 - val_loss: 2.4398 - learning_rate: 5.0000e-04
Epoch 3/50
[1m217/264[0m [32m━━━━━━━━━━━━━━━━[0m[37m━━━━[0m [1m0s[0m 3ms/step - accuracy: 0.7006 - loss: 1.0805

KeyboardInterrupt: 

In [None]:
# Test

test_loss, test_acc = model.evaluate(
    x=[X_injury_counts_test, X_position_test, X_numerical_test],
    y=y_test,
    verbose=1
)

print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.4275 - loss: 3.8907 
Test Loss: 3.6563382148742676
Test Accuracy: 0.45501285791397095


In [24]:
y_test

array([[0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.]])

In [None]:
model.predict([X_injury_counts_test, X_position_test, X_numerical_test])[0]

X_numerical_test

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 


array([[-1.33735599, -0.79805714, -0.02715737, -0.62592931,  1.83361952,
        -0.05363306],
       [ 2.3530824 ,  1.95647683, -0.80784118,  1.16639206,  0.17330845,
        -0.58148645],
       [-0.59926831,  1.61820073, -1.19818308,  1.54773704,  1.66758841,
        -0.84541314],
       ...,
       [-0.59926831, -1.08800808, -0.41749927, -1.00727428, -1.81906485,
         1.39796375],
       [ 1.24595088,  1.59403815, -0.41749927,  0.21302963,  1.83361952,
         0.07833028],
       [-0.96831215, -0.26648041, -0.41749927, -0.13018084, -0.15875377,
         0.07833028]])