In [106]:
from __future__ import print_function, division
import torch
import pandas as pd
import numpy as np
import os
from skimage import io, transform

In [125]:
from torch.utils.data import Dataset

class FramesDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.frames = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
    
    def __len__(self):
        return len(self.frames)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
  
        exc_type = int(positions[22][-1])
        
        # index 0 is frame number, index 22 is exc type
        positions = self.frames.iloc[idx, :]
        positions = np.array(positions[1:22])  
        positions = positions.astype('float').reshape(-1, 21)

        if self.transform:
            positions = self.transform(positions)
        return positions, exc_type
    

In [126]:
train = FramesDataset(csv_file = 'data/train.csv', root_dir='data/')
test = FramesDataset(csv_file = 'data/test.csv', root_dir='data/')

print(len(train))
print(train[100])

trainset = torch.utils.data.DataLoader(train, batch_size=5, shuffle=True)
testset = torch.utils.data.DataLoader(test, batch_size=5, shuffle=False)

print(trainset)

1600
[-0.2237914651632309 -0.05457500740885735 -0.2526101768016815
 -0.2048950046300888 0.028485439717769626 -0.4734874963760376 0.0 0.0 0.0
 0.11020604521036148 -0.03754030168056488 -1.4930717945098877 0.0 0.0 0.0
 0.2285628467798233 -0.053275596350431435 -0.2522951662540436
 0.21870994567871094 0.029237512499094013 -0.472160816192627]
(array([[-0.22379147, -0.05457501, -0.25261018, -0.204895  ,  0.02848544,
        -0.4734875 ,  0.        ,  0.        ,  0.        ,  0.11020605,
        -0.0375403 , -1.49307179,  0.        ,  0.        ,  0.        ,
         0.22856285, -0.0532756 , -0.25229517,  0.21870995,  0.02923751,
        -0.47216082]]), 1)
<torch.utils.data.dataloader.DataLoader object at 0x1a2b354240>


In [127]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(21, 48)
        self.fc3 = nn.Linear(48, 2)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc3(x)
#         return x
        return F.log_softmax(x, dim=1)   # taking log_softmax not needed, for optimization?

net = Net().double()

In [128]:
import torch.optim as optim

# algorithm for first-order gradient-based optimization, built into PyTorch.
optimizer = optim.Adam(net.parameters(), lr=0.001)  
# optimizer = optim.SGD(net.parameters(), lr=0.01)

In [129]:
for i, data in enumerate(trainset, 0):
    x, y = data
    print(data)

[-0.2548377811908722 -0.012262092903256416 -0.2289300113916397
 -0.4176445007324219 0.08188863098621367 -0.3684267997741699 0.0 0.0 0.0
 0.10589034855365753 0.015069430693984032 -1.4843846559524536 0.0 0.0 0.0
 0.2570152580738068 -0.015307197347283363 -0.22977718710899356
 0.4324711859226227 0.07524202764034271 -0.37872222065925604]
[-0.2606731355190277 0.00249108811840415 -0.2154647409915924
 -0.4079771935939789 0.13292670249938965 -0.3077040910720825 0.0 0.0 0.0
 0.1167714148759842 0.01388594601303339 -1.4960218667984009 0.0 0.0 0.0
 0.2837972939014435 0.0027664299122989178 -0.20929864048957825
 0.4514746069908142 0.13830558955669406 -0.29931288957595825]
[-0.4528589248657226 0.06479823589324951 -0.02075690217316151
 -0.2970333695411682 0.1700676828622818 -0.0424746498465538 0.0 0.0 0.0
 0.12861371040344238 -0.05038581788539887 -1.5364749431610107 0.0 0.0 0.0
 0.4606923162937164 0.04678179323673248 0.03060873225331306
 0.3035728931427002 0.1201907992362976 -0.016068905591964718]
[-0.

[tensor([[[-2.1741e-01, -4.6517e-02, -2.3876e-01, -2.4232e-01,  1.3699e-03,
          -4.8203e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1777e-01,
          -1.5654e-02, -1.4780e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           2.5369e-01, -4.7587e-02, -2.3589e-01,  3.1633e-01, -3.2809e-04,
          -4.7866e-01]],

        [[-2.5155e-01, -5.4519e-02, -2.4046e-01, -2.1373e-01,  2.5697e-02,
          -4.6250e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1315e-01,
          -3.3077e-02, -1.4971e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           2.5558e-01, -5.6953e-02, -2.3883e-01,  2.1924e-01,  2.3978e-02,
          -4.5949e-01]],

        [[-3.7082e-01,  1.3750e-01,  5.4848e-02, -4.1265e-01,  2.8289e-01,
           2.3738e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5053e-01,
           7.8752e-02, -1.5503e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.3572e-01,  1.3748e-01,  4.8432e-02,  3.5501e-01,  2.7220e-01,
           2.4049e-01]],

        [[-3.6815e-01

 0.2865303158760071 0.050518497824668884 -0.42709359526634216]
[tensor([[[-3.6161e-01,  5.1069e-02, -1.1057e-01, -5.2882e-01,  2.1076e-01,
          -4.3448e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.2980e-01,
           2.2367e-02, -1.5239e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.6620e-01,  5.8744e-02, -9.9959e-02,  5.2688e-01,  2.2992e-01,
          -2.0184e-02]],

        [[-2.8885e-01,  5.6154e-04, -2.4500e-01, -2.1209e-01,  1.1376e-01,
          -4.8254e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0725e-01,
          -1.1117e-01, -1.4503e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.2364e-01,  4.0051e-03, -2.3930e-01,  2.8412e-01,  1.2709e-01,
          -4.6992e-01]],

        [[-2.5592e-01, -2.4810e-02, -2.1484e-01, -2.1802e-01,  1.2739e-01,
          -3.7895e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1086e-01,
          -3.2095e-02, -1.5001e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           2.7745e-01, -1.5233e-02, -2.1219e-01,  2.4953e-0

 0.28029006719589233 0.20996466279029846 0.3273237645626068]
[-0.21289104223251346 0.05037959665060043 -0.25134623050689703
 -0.12761235237121582 0.1938921809196472 -0.4227370023727417 0.0 0.0 0.0
 0.09297305345535277 -0.08925898373126984 -1.4778467416763306 0.0 0.0 0.0
 0.2117812931537628 0.0440078005194664 -0.2554553151130676
 0.14142067730426788 0.19209301471710205 -0.4249344766139984]
[tensor([[[-0.3822,  0.0493, -0.0139, -0.5285,  0.1674,  0.1193,  0.0000,
           0.0000,  0.0000,  0.1447,  0.0097, -1.5565,  0.0000,  0.0000,
           0.0000,  0.3814,  0.0401, -0.0240,  0.5085,  0.1747,  0.1051]],

        [[-0.3753,  0.0068, -0.1070, -0.2131,  0.0489, -0.2727,  0.0000,
           0.0000,  0.0000,  0.1249, -0.0121, -1.5409,  0.0000,  0.0000,
           0.0000,  0.3724, -0.0032, -0.0995,  0.2190,  0.0438, -0.2642]],

        [[-0.2207, -0.0381, -0.2560, -0.1703,  0.0600, -0.4632,  0.0000,
           0.0000,  0.0000,  0.1050, -0.0507, -1.4901,  0.0000,  0.0000,
           0.0000

[tensor([[[-3.5878e-01, -4.7746e-02, -1.5664e-01, -2.3139e-01,  7.9777e-02,
          -2.9251e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1104e-01,
          -4.6433e-02, -1.5322e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.6854e-01, -6.0113e-02, -1.3739e-01,  2.4425e-01,  5.7549e-02,
          -2.7760e-01]],

        [[-3.3724e-01, -1.0287e-02, -2.2242e-01, -2.3830e-01,  1.1473e-01,
          -4.2749e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0086e-01,
          -9.7550e-02, -1.4739e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.4857e-01, -1.1424e-02, -2.1496e-01,  2.5528e-01,  1.0981e-01,
          -4.1766e-01]],

        [[-3.9219e-01,  3.1965e-02, -3.7456e-02, -2.0854e-01,  5.0324e-02,
          -1.7358e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.2677e-01,
          -1.3775e-02, -1.5541e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.9053e-01,  2.6047e-02, -2.7180e-02,  2.1246e-01,  3.8929e-02,
          -1.6939e-01]],

        [[-2.0879e-01

[-0.4485262334346771 0.030102573335170742 -0.05589067190885544
 -0.3057243227958679 0.16000714898109436 -0.08339792490005493 0.0 0.0 0.0
 0.13698089122772214 -0.032688017934560776 -1.5399038791656494 0.0 0.0 0.0
 0.4569879472255707 0.05679304152727128 -0.01970129832625389
 0.3105329871177673 0.1583673506975174 -0.07110147178173065]
[-0.3079789876937866 0.006734399124979973 -0.2290107160806656
 -0.21692512929439545 0.13669349253177646 -0.4483002126216888 0.0 0.0 0.0
 0.10870231688022614 -0.10224981606006622 -1.4713170528411863 0.0 0.0 0.0
 0.3552465438842773 0.005453757941722871 -0.21837352216243744
 0.2984582185745239 0.13957054913043976 -0.4218840003013611]
[-0.2415422797203064 0.0374438501894474 -0.24277889728546145
 -0.14599768817424774 0.1802154779434204 -0.4086587131023407 0.0 0.0 0.0
 0.0994681790471077 -0.08758202940225601 -1.48917555809021 0.0 0.0 0.0
 0.261298805475235 0.02428599447011948 -0.23906834423542025
 0.18641121685504916 0.1640009731054306 -0.4099650979042053]
[-0.402

[tensor([[[-3.7467e-01,  1.3719e-01,  4.2174e-02, -4.2853e-01,  2.8405e-01,
           2.3729e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5563e-01,
           7.1545e-02, -1.5501e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.4674e-01,  1.4963e-01,  4.6988e-02,  3.8371e-01,  2.9348e-01,
           2.5615e-01]],

        [[-3.7652e-01,  8.3641e-02,  1.0724e-01, -1.8881e-01,  5.0034e-02,
           7.0023e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3243e-01,
          -2.0183e-03, -1.5710e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.7370e-01,  8.2400e-02,  1.1812e-01,  1.8890e-01,  4.1953e-02,
           6.3380e-02]],

        [[-2.0504e-01,  1.7066e-02, -2.7562e-01, -2.0397e-01,  1.3064e-01,
          -5.2638e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  8.4664e-02,
          -1.0817e-01, -1.4512e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           1.9782e-01,  1.2098e-02, -2.7458e-01,  2.1300e-01,  1.1611e-01,
          -5.2611e-01]],

        [[-3.2745e-01

[-0.3418197631835937 0.15171349048614502 0.07754293084144592
 -0.3429335057735443 0.2528838515281677 0.2734496295452118 0.0 0.0 0.0
 0.14793135225772858 0.06876897811889647 -1.5511343479156494 0.0 0.0 0.0
 0.3225340843200684 0.1471397876739502 0.07148072123527527
 0.3327257037162781 0.2419577240943909 0.2749395966529846]
[-0.3236950933933258 0.15833139419555664 0.11268073320388795
 -0.2921513617038727 0.2419978231191635 0.3154260516166687 0.0 0.0 0.0
 0.15076373517513275 0.06386104971170425 -1.5579892396926882 0.0 0.0 0.0
 0.314926952123642 0.1613369584083557 0.11097913980484007
 0.2975935935974121 0.2334786653518677 0.3197939991950989]
[-0.44468316435813904 0.06484654545783998 -0.04056775942444801
 -0.29358041286468506 0.18687467277050016 -0.07271440327167511 0.0 0.0 0.0
 0.12903021275997162 -0.04557614400982857 -1.5383214950561523 0.0 0.0 0.0
 0.460034042596817 0.0437520295381546 -0.0038457326591014858
 0.31321388483047485 0.13605163991451266 -0.05838118866086006]
[-0.320887506008148

[tensor([[[-0.3606,  0.0777, -0.0865, -0.5010,  0.2282,  0.0235,  0.0000,
           0.0000,  0.0000,  0.1374,  0.0173, -1.5376,  0.0000,  0.0000,
           0.0000,  0.3672,  0.0773, -0.0818,  0.5013,  0.2270,  0.0403]],

        [[-0.3896,  0.0984, -0.0447, -0.5416,  0.2810,  0.0162,  0.0000,
           0.0000,  0.0000,  0.1383,  0.0237, -1.5519,  0.0000,  0.0000,
           0.0000,  0.3813,  0.0923, -0.0418,  0.5386,  0.2686,  0.0273]],

        [[-0.3827,  0.0023, -0.1744, -0.2687,  0.1445, -0.2981,  0.0000,
           0.0000,  0.0000,  0.1256, -0.0821, -1.4997,  0.0000,  0.0000,
           0.0000,  0.4122,  0.0192, -0.1601,  0.2986,  0.1694, -0.2904]],

        [[-0.3684,  0.1503,  0.0440, -0.3956,  0.2949,  0.2422,  0.0000,
           0.0000,  0.0000,  0.1596,  0.0747, -1.5472,  0.0000,  0.0000,
           0.0000,  0.3491,  0.1552,  0.0484,  0.3721,  0.2954,  0.2519]],

        [[-0.2190, -0.0425, -0.2559, -0.1700,  0.0490, -0.4656,  0.0000,
           0.0000,  0.0000,  0.1057, -

[tensor([[[-0.3872,  0.0976,  0.0885, -0.1829,  0.0492,  0.0548,  0.0000,
           0.0000,  0.0000,  0.1382, -0.0092, -1.5670,  0.0000,  0.0000,
           0.0000,  0.3823,  0.0911,  0.0903,  0.1894,  0.0466,  0.0568]],

        [[-0.3774,  0.0604, -0.0772, -0.2046,  0.1337, -0.1651,  0.0000,
           0.0000,  0.0000,  0.1261, -0.0344, -1.5464,  0.0000,  0.0000,
           0.0000,  0.3758,  0.0549, -0.0874,  0.2194,  0.1413, -0.1717]],

        [[-0.2587, -0.0379, -0.2216, -0.2820,  0.0777, -0.4197,  0.0000,
           0.0000,  0.0000,  0.1188, -0.0257, -1.5064,  0.0000,  0.0000,
           0.0000,  0.2657, -0.0306, -0.2250,  0.2825,  0.0964, -0.4210]],

        [[-0.3603,  0.1279,  0.0203, -0.4535,  0.2893,  0.1695,  0.0000,
           0.0000,  0.0000,  0.1402,  0.0251, -1.5489,  0.0000,  0.0000,
           0.0000,  0.3595,  0.1279,  0.0284,  0.4616,  0.2803,  0.1777]],

        [[-0.3767,  0.0099, -0.0970, -0.2316,  0.1212, -0.1979,  0.0000,
           0.0000,  0.0000,  0.1185, -

[tensor([[[-3.3842e-01, -7.9441e-03, -2.2222e-01, -2.3835e-01,  1.0489e-01,
          -4.2003e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.2100e-01,
          -9.4695e-02, -1.4785e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.7487e-01,  4.3310e-03, -2.0747e-01,  3.0654e-01,  1.3369e-01,
          -3.9683e-01]],

        [[-2.6931e-01, -3.3362e-03, -2.1351e-01, -4.3306e-01,  1.1095e-01,
          -3.1105e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1818e-01,
           7.1471e-04, -1.4948e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           2.9653e-01,  2.7074e-03, -2.0539e-01,  4.8641e-01,  1.1987e-01,
          -2.9465e-01]],

        [[-3.2360e-01,  2.1305e-02, -1.6928e-01, -4.1815e-01,  1.9374e-01,
          -2.8634e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.2511e-01,
          -3.5881e-02, -1.5267e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           3.2394e-01,  3.1490e-02, -1.6717e-01,  4.1185e-01,  2.1814e-01,
          -2.5353e-01]],

        [[-2.3006e-01

[-0.351289838552475 -0.023851295933127403 -0.21273961663246155
 -0.2377858459949493 0.1075783222913742 -0.3834273815155029 0.0 0.0 0.0
 0.12515641748905182 -0.07660607993602753 -1.4869742393493652 0.0 0.0 0.0
 0.3838324546813965 -0.008733988739550114 -0.19965340197086331
 0.2932157814502716 0.12128565460443495 -0.3736099004745483]
[-0.21496723592281344 0.042502544820308685 -0.27612197399139404
 -0.2014097422361374 0.19155338406562805 -0.4964124858379364 0.0 0.0 0.0
 0.08047056198120117 -0.10981065779924393 -1.449036955833435 0.0 0.0 0.0
 0.2046712785959244 0.05851102247834205 -0.27220311760902405
 0.16943374276161194 0.19918537139892575 -0.4861096441745758]
[-0.34132614731788635 0.1456765979528427 0.053455058485269547
 -0.37662890553474426 0.2842933535575867 0.2324106991291046 0.0 0.0 0.0
 0.14388428628444672 0.04029927402734757 -1.55071222782135 0.0 0.0 0.0
 0.3388327360153198 0.1467597782611847 0.05945868790149689
 0.3867438435554504 0.2808995544910431 0.2391139566898346]
[tensor([[[

 0.18334408104419708 0.040110655128955834 0.05895965173840523]
[-0.3720894753932953 0.09665872156620026 0.10468392074108124
 -0.17920058965682986 0.047839924693107605 0.07146614044904709 0.0 0.0 0.0
 0.13220597803592682 0.0018380224937573075 -1.5711418390274048 0.0 0.0 0.0
 0.371897965669632 0.09201131016016006 0.11734580248594285
 0.18330737948417666 0.03790539875626564 0.06212344020605087]
[-0.32955148816108704 -0.03889772668480873 -0.13892343640327454
 -0.4998756051063538 0.018044721335172653 -0.2376803159713745 0.0 0.0 0.0
 0.10480270534753801 -0.01854604110121727 -1.5178009271621704 0.0 0.0 0.0
 0.31257694959640503 -0.04825665056705475 -0.13635724782943726
 0.4815692901611328 0.004932355135679245 -0.2195950746536255]
[-0.22778667509555814 -0.028815802186727524 -0.23804925382137296
 -0.3274761140346527 0.04208458960056305 -0.4383261203765869 0.0 0.0 0.0
 0.11526311188936235 0.023931039497256282 -1.4785141944885254 0.0 0.0 0.0
 0.2594384849071503 -0.03308502584695816 -0.232692003250

[tensor([[[-0.4472,  0.0521, -0.0612, -0.2943,  0.1723, -0.1043,  0.0000,
           0.0000,  0.0000,  0.1316, -0.0459, -1.5308,  0.0000,  0.0000,
           0.0000,  0.4587,  0.0562, -0.0180,  0.2978,  0.1516, -0.0762]],

        [[-0.4072,  0.0597,  0.0113, -0.2248,  0.1122, -0.0585,  0.0000,
           0.0000,  0.0000,  0.1370, -0.0372, -1.5603,  0.0000,  0.0000,
           0.0000,  0.3987,  0.0523,  0.0220,  0.2255,  0.1028, -0.0614]],

        [[-0.3612,  0.0516, -0.0999, -0.5476,  0.2014, -0.0472,  0.0000,
           0.0000,  0.0000,  0.1243,  0.0189, -1.5266,  0.0000,  0.0000,
           0.0000,  0.3525,  0.0501, -0.1035,  0.5260,  0.1947, -0.0398]],

        [[-0.4471,  0.0496, -0.0773, -0.3031,  0.1895, -0.1582,  0.0000,
           0.0000,  0.0000,  0.1372, -0.0538, -1.5283,  0.0000,  0.0000,
           0.0000,  0.4573,  0.0403, -0.0507,  0.3313,  0.1478, -0.1516]],

        [[-0.2131,  0.0313, -0.2749, -0.2047,  0.1720, -0.5035,  0.0000,
           0.0000,  0.0000,  0.0794, -

In [130]:
EPOCHS = 3

for epoch in range(EPOCHS):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainset, 0):
        # get the inputs; data is dictionary
        x, y = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        output = net(x)
        print(output)
        loss = F.nll_loss(output, y)
        loss.backward()
        optimizer.step()

        # print statistics
#         running_loss += loss.item()
#         if i % 2000 == 1999:    # print every 2000 mini-batches
#             print('[%d, %5d] loss: %.3f' %
#                   (epoch + 1, i + 1, running_loss / 2000))
#             running_loss = 0.0

print('Finished Training')

[-0.20603977143764496 0.013987284153699877 -0.2761968672275543
 -0.19866523146629333 0.13674086332321167 -0.5227307081222534 0.0 0.0 0.0
 0.0825752541422844 -0.11802864074707033 -1.454328536987305 0.0 0.0 0.0
 0.19852524995803836 0.01957390084862709 -0.2750842869281769
 0.2041141390800476 0.13360409438610074 -0.5214259028434753]
[-0.3530977070331573 -0.00963177066296339 -0.15114760398864746
 -0.20599889755249026 0.1036943793296814 -0.2879768908023834 0.0 0.0 0.0
 0.11402542144060135 -0.07706739008426666 -1.532421588897705 0.0 0.0 0.0
 0.3625690340995789 -0.014637320302426815 -0.14080223441123962
 0.22001712024211886 0.09517720341682434 -0.2726815044879913]
[-0.3150325119495392 0.01238398440182209 -0.18654274940490725
 -0.4981613755226135 0.1350822150707245 -0.2461068332195282 0.0 0.0 0.0
 0.11472585797309875 0.011506356298923492 -1.505521535873413 0.0 0.0 0.0
 0.3193698525428772 0.011468361131846905 -0.18883419036865234
 0.510678768157959 0.13513097167015076 -0.24938809871673584]
[-0.4

ValueError: Expected target size (5, 2), got torch.Size([5])