In [None]:
import pandas as pd
from torch.utils.data import Dataset
import glob
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader
from torch import nn
import torch
from tqdm import tqdm

In [None]:
import wandb

In [None]:
class HandLandmarksDataset(Dataset):
    """Map-style dataset pairing feature rows with their labels.

    Args:
        X: 2-D NumPy array; row ``i`` holds the features of sample ``i``.
        Y: Indexable collection of labels aligned with the rows of ``X``.
    """

    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __len__(self):
        """Return the number of samples, i.e. the number of rows in ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the sample at position ``idx``.

        The feature row is converted to a tensor with ``torch.from_numpy``;
        the label is returned exactly as stored in ``Y``.
        """
        features = torch.from_numpy(self.X[idx, :])
        return features, self.Y[idx]

In [None]:
# Load the gesture table. Every column except "filename" and "label" is
# treated as a feature column.
img_labels = pd.read_csv("../data_collection/data/gestures.csv")
feature_frame = img_labels.drop(columns=["filename", "label"])
columns = list(feature_frame.columns)

In [None]:
# Feature matrix: one row per sample, cast to float32.
feature_values = img_labels[columns].values
X = feature_values.astype("float32")

In [None]:
# Label vector, aligned row-for-row with X.
Y = img_labels["label"].values

In [None]:
# Single stratified shuffle split with a fixed seed; 15% of the rows are held out.
sss = StratifiedShuffleSplit(
    test_size=0.15,
    n_splits=1,
    random_state=42,
)

In [None]:
# The splitter yields exactly one (train, test) pair of row-index arrays.
train_index, test_index = next(sss.split(X, Y))

In [None]:
# Materialise the train / test partitions from the index arrays.
X_train, X_test = X[train_index], X[test_index]
Y_train, Y_test = Y[train_index], Y[test_index]

In [None]:
# Second splitter: the held-out fraction is chosen so that the split taken
# from the remaining training rows has as many rows as the test set.
holdout_fraction = len(Y_test) / len(X_train)
sss = StratifiedShuffleSplit(
    test_size=holdout_fraction,
    n_splits=1,
    random_state=42,
)

In [None]:
# Index arrays returned here are positions within X_train / Y_train.
train_index, valid_index = next(sss.split(X_train, Y_train))

In [None]:
# `train_index` / `valid_index` were produced by splitting X_train / Y_train,
# so they are positions within that subset, not within the full X / Y arrays.
# Indexing X / Y with them selects the wrong rows and can put held-out test
# rows into the training and validation sets.
# The validation rows are taken first because X_train / Y_train are rebound
# right after. Re-run the split cells before re-running this cell.
X_valid = X_train[valid_index, :]
Y_valid = Y_train[valid_index]
X_train = X_train[train_index, :]
Y_train = Y_train[train_index]

In [None]:
# Wrap each split in a Dataset, with every name matching the split it holds.
# Previously `test_data` wrapped the validation arrays and `valid_data` wrapped
# the test arrays, so early stopping (driven by `valid_dataloader`) was
# selecting the model on the test set.
training_data = HandLandmarksDataset(X_train, Y_train)
valid_data = HandLandmarksDataset(X_valid, Y_valid)
test_data = HandLandmarksDataset(X_test, Y_test)

In [None]:
# Batched, shuffled iterators over the training and test datasets.
train_dataloader = DataLoader(dataset=training_data, shuffle=True, batch_size=512)
test_dataloader = DataLoader(dataset=test_data, shuffle=True, batch_size=128)

In [None]:
# Batched, shuffled iterator over the validation dataset.
valid_dataloader = DataLoader(dataset=valid_data, shuffle=True, batch_size=512)

In [None]:
class HandModel(nn.Module):
    """Two-layer fully connected classifier that returns unnormalised scores.

    Args:
        n_features: Length of each input vector (default 63).
        n_hidden: Width of the hidden layer (default 21).
        n_classes: Number of scores produced per sample (default 6).
    """

    def __init__(self, n_features=63, n_hidden=21, n_classes=6):
        super().__init__()
        # Layers are created and re-initialised in a fixed order so that a
        # seeded RNG always produces the same weights.
        self.linear1 = nn.Linear(n_features, n_hidden)
        nn.init.xavier_uniform_(self.linear1.weight)
        self.linear2 = nn.Linear(n_hidden, n_classes)
        nn.init.xavier_uniform_(self.linear2.weight)

    def forward(self, x):
        """Return logits for ``x``; the last dimension of ``x`` must be ``n_features``."""
        hidden = nn.functional.leaky_relu(self.linear1(x))
        return self.linear2(hidden)

In [None]:
# Instantiate the classifier; its weights are freshly initialised every time
# this cell runs.
model = HandModel()

In [None]:
# Adam over all model parameters, paired with a cross-entropy loss on the
# model's logits.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [None]:
# Start the W&B run and attach the hyperparameters to it via `config=`.
# Assigning `wandb.config = {...}` only rebinds the module attribute to a plain
# dict, which is not how a run's config gets populated.
wandb.init(
    project="multimodal",
    entity="mazza",
    config={
        "learning_rate": 0.001,
        "epochs": 5000,
        "batch_size": 512,
    },
)
# Let W&B hook into the model.
wandb.watch(model)


In [None]:
N_EPOCHS = 5000
# Training stops once the validation loss has failed to improve for more than
# this many consecutive epochs.
EARLY_STOPPING = 500
# Start from +inf so the first epoch always counts as an improvement.
best_loss = float("inf")
stopping_count = 0

# Example input used only to trace the model during ONNX export. Its width is
# read from the first layer so it always matches the model. It was previously
# undefined, which raised NameError on a fresh kernel.
dummy_input = torch.zeros(1, model.linear1.in_features)

for epoch in tqdm(range(N_EPOCHS)):
    # ---- training pass ----
    # Undo the eval() call made at the end of the previous epoch.
    model.train()
    for inputs, targets in train_dataloader:
        optimizer.zero_grad()
        yhat = model(inputs)
        loss = criterion(yhat, targets)
        loss.backward()
        optimizer.step()

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        # Average the loss over the whole validation set, weighting each batch
        # by its size. The last shuffled batch alone is a noisy
        # early-stopping signal.
        loss_sum = 0.0
        n_samples = 0
        # Loop names deliberately differ from the global arrays X / Y so the
        # notebook's data variables are not overwritten.
        for valid_inputs, valid_targets in valid_dataloader:
            pred = model(valid_inputs)
            batch_len = len(valid_targets)
            loss_sum += criterion(pred, valid_targets).item() * batch_len
            n_samples += batch_len
        valid_loss = loss_sum / n_samples

        if valid_loss < best_loss:
            best_loss = valid_loss
            stopping_count = 0

            # Checkpoint the best model so far as ONNX with a dynamic batch axis.
            # NOTE(review): the file name looks copied from an unrelated
            # example; kept unchanged in case something reads this path.
            torch.onnx.export(model,
                              dummy_input,
                              "super_resolution.onnx",
                              export_params=True,
                              opset_version=10,
                              do_constant_folding=True,
                              input_names=['input'],
                              output_names=['output'],
                              dynamic_axes={'input': {0: 'batch_size'},
                                            'output': {0: 'batch_size'}})
        else:
            stopping_count += 1

    # Log plain floats, and log before the early-stopping check so the final
    # epoch is recorded as well.
    wandb.log({"valid_loss": valid_loss, "last_batch_train_loss": loss.item()})

    if stopping_count > EARLY_STOPPING:
        break

wandb: Network error (ConnectionError), entering retry loop.
