In [None]:
import torch 
import numpy as np
import pandas as pd 
from torchvision import transforms,datasets 
from torch.utils.data import DataLoader, Dataset
from model import PoseClassification

# Pick the compute device once for the whole notebook: the GPU when torch can
# see a CUDA device, otherwise the CPU.
if torch.cuda.is_available():
    dev = torch.device('cuda')
else:
    dev = torch.device('cpu')
print(f'Device : {dev}')

In [None]:
class csv_data(Dataset):
    """Map-style dataset over a CSV of labelled pose samples.

    The whole CSV is loaded into memory with ``pd.read_csv`` (header row
    consumed) and converted to a NumPy array. Each remaining row is one
    sample: column 0 is the label, every later column is a feature.

    Args:
        csv_path: Path to the CSV file.
        transform: Optional callable applied to the float32 feature tensor
            before it is returned from ``__getitem__``.
    """

    def __init__(self, csv_path, transform = None):
        self.arr = np.array(pd.read_csv(csv_path))
        self.len = self.arr.shape[0]
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(features, label)`` for row ``index``.

        ``features`` is a 1-D float32 tensor built from columns 1..end;
        ``label`` is the raw column-0 value from the underlying array.
        """
        row = self.arr[index]
        # Convert via an explicit float32 array: this works for both numeric
        # and object-dtype arrays (torch.from_numpy rejects object dtype).
        xvals = torch.tensor(np.asarray(row[1:], dtype=np.float32))
        yvals = row[0]

        if self.transform:
            # Apply the optional transform to the feature tensor and keep
            # its result.
            xvals = self.transform(xvals)
        return xvals, yvals

    def __len__(self):
        return self.len

In [None]:
# Move name -> integer class id.
move2id = {
    'no_move': 0,
    'hook': 1,
    'uppercut': 2,
    'special': 3,
    'kick': 4,
}
# Reverse lookup: integer class id -> move name.
id2move = dict(zip(move2id.values(), move2id.keys()))

In [None]:
# Number of model outputs: one per entry in the move2id label map.
OUT_FEATURES = len(move2id)
# Width of each input feature vector: 19 * 3 values (presumably 19 pose
# keypoints with 3 values each — TODO confirm against the dataset CSV).
IN_FEATURES = 3 * 19
# Samples per optimizer step.
BATCH_SIZE = 10

In [None]:
import os

# Labelled pose CSV (label in column 0, features in the remaining columns).
# Set the POSE2PLAY_DATASET environment variable to point elsewhere instead
# of editing this machine-specific default.
DATASET_CSV = os.environ.get(
    'POSE2PLAY_DATASET',
    '/Users/gursi/Desktop/Pose2Play/move_dataset/dataset.csv',
)

dataset = csv_data(DATASET_CSV)
# Fail fast with a readable message if the CSV width disagrees with the
# configured model input size, rather than erroring later inside the model.
n_feature_cols = dataset.arr.shape[1] - 1
assert n_feature_cols == IN_FEATURES, (
    f'CSV has {n_feature_cols} feature columns, model expects {IN_FEATURES}'
)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
model = PoseClassification(IN_FEATURES, OUT_FEATURES).to(dev)

In [None]:
import os

LR = 0.00001
EPOCHS = 30
LOG_EVERY = 50  # batches between loss reports / checkpoint saves
WEIGHTS_PATH = '/Users/gursi/Desktop/hackathon/weights/trained_1.pt'

# torch.save does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)

opt = torch.optim.Adam(model.parameters(), lr=LR)
crit = torch.nn.CrossEntropyLoss()

model.train()  # explicit training mode (matters if the model has dropout/batchnorm)

total_batch_loss = 0.0
for e in range(EPOCHS):
    print(f'Epoch : {e}')
    print()
    # Reset the running window every epoch so a report never mixes losses
    # from two different epochs.
    total_batch_loss = 0.0
    window_batches = 0
    for batch_id, (coords, labels) in enumerate(loader):

        coords = coords.to(dev)
        # CrossEntropyLoss expects integer class indices as targets.
        labels = labels.to(dev).to(torch.int64)

        opt.zero_grad()
        yhat = model(coords)
        loss = crit(yhat, labels)
        loss.backward()
        opt.step()

        total_batch_loss += loss.item()
        window_batches += 1
        if batch_id % LOG_EVERY == 0:
            print(f'Batch : {batch_id}')
            # Report the mean over the window so numbers are comparable
            # regardless of how many batches the window covered.
            print(f' | Loss : {total_batch_loss / window_batches}')
            total_batch_loss = 0.0
            window_batches = 0
            torch.save(model.state_dict(), WEIGHTS_PATH)

# The in-loop save only fires every LOG_EVERY batches, so persist the final
# weights explicitly; otherwise the last optimizer steps would be lost.
torch.save(model.state_dict(), WEIGHTS_PATH)