In [189]:
import pandas as pd
import pdb
from pathlib import Path
from fastai.basics import *
path = Path("UCRArchive_2018")

In [4]:
list((path/"car").iterdir())

[WindowsPath('UCRArchive_2018/car/Car_TEST.tsv'),
 WindowsPath('UCRArchive_2018/car/Car_TRAIN.tsv'),
 WindowsPath('UCRArchive_2018/car/desktop.ini'),
 WindowsPath('UCRArchive_2018/car/README.md')]

In [190]:
class TSCDS(torch.utils.data.Dataset):
    def __init__(self,X,Y):
        self.X = X
        self.Y = Y
        self.items = [1,2,3]
    
    def __len__(self): return len(self.X)
    
    def __getitem__(self,idx):
        return np.expand_dims(self.X[idx],axis=0), self.Y[idx] - 1

In [206]:
class ConvBlock(torch.nn.Module):
    def __init__(self,inputC,outputC,kernelSize):
        super().__init__()
        self.conv = torch.nn.Conv1d(inputC,outputC,kernelSize)
        self.drop = torch.nn.Dropout()
        self.bn = torch.nn.BatchNorm1d(outputC)
        
    def forward(self,x):
        return self.drop(self.bn(torch.relu(self.conv(x))))

class TSCModel(torch.nn.Module):
    def __init__(self,numClasses):
        super().__init__()
        self.blocks = torch.nn.ModuleList([ConvBlock(i,o,k) for i,o,k in [(1,128,8),(128,256,5),(256,128,3)]])
        self.avgpool = torch.nn.AdaptiveAvgPool1d(1)
        self.out = torch.nn.Linear(128,numClasses)
        
    def forward(self,x):
        x = x.float()
        for b in self.blocks: x = b(x)
        x = self.avgpool(x)
        return self.out(x.squeeze(-1))

In [207]:
def trainNetwork(d,weights=None,epochs=20,lr=5e-2):
    trainDF = pd.read_csv(path/f"{d}/{d}_TRAIN.tsv",sep="\t",header=None)
    testDF = pd.read_csv(path/f"{d}/{d}_TEST.tsv",sep="\t",header=None)
    cat = trainDF.iloc[:,0].astype("category")
    trainDF.iloc[:,0] = cat.cat.codes
    testDF.iloc[:,0] = pd.Categorical(testDF.iloc[:,0],categories=cat.cat.categories).codes
    nClasses = len(cat.cat.categories)
    
    trainDS = TSCDS(trainDF.iloc[:,1:].values,trainDF.iloc[:,0].values)
    testDS = TSCDS(testDF.iloc[:,1:].values,testDF.iloc[:,0].values)
    model = TSCModel(nClasses)
    learn = Learner(data,model,loss_func=torch.nn.functional.cross_entropy,metrics=[accuracy])
    if weights:
        learn.model.blocks.load_state_dict(weights)
    
    learn.fit_one_cycle(epochs,lr)
    return learn.model.blocks.state_dict()

In [193]:
weights = None
for p in path.iterdir():
    if p.is_dir():
        weights = trainNetwork(p.name,weights)

epoch,train_loss,valid_loss,accuracy
1,1.749991,1.465224,0.233333
2,1.574655,1.504321,0.216667
3,1.503467,2.492872,0.233333
4,1.515240,1.740129,0.233333
5,1.459075,1.433599,0.333333
6,1.467439,1.397818,0.350000
7,1.421116,2.365916,0.233333
8,1.423876,1.537795,0.283333
9,1.423254,1.377244,0.316667


KeyboardInterrupt: 

In [208]:
w = trainNetwork("Worms",lr=1e-2,epochs=100)

epoch,train_loss,valid_loss,accuracy
1,1.615551,1.568383,0.416667
2,1.589636,1.557210,0.333333
3,1.563373,1.455474,0.416667
4,1.526425,1.336305,0.500000
5,1.481152,1.226602,0.633333
6,1.417295,1.329268,0.566667
7,1.360600,1.359261,0.300000
8,1.303365,1.243345,0.516667
9,1.257618,1.069968,0.616667
10,1.208771,1.193387,0.483333
11,1.233717,1.366392,0.450000
12,1.249702,1.334387,0.383333
13,1.242248,1.270352,0.400000
14,1.198372,1.997015,0.233333
15,1.153669,1.050262,0.533333
16,1.093696,1.032445,0.600000
17,1.187007,0.979278,0.550000
18,1.216366,2.377897,0.316667
19,1.215010,1.346071,0.416667
20,1.232472,1.590836,0.400000
21,1.162961,1.145141,0.466667
22,1.188477,1.642459,0.383333
23,1.196061,1.557037,0.366667
24,1.212500,1.689725,0.266667
25,1.187407,1.083794,0.466667
26,1.193559,1.323884,0.350000
27,1.174371,1.547679,0.450000
28,1.146998,1.143000,0.583333
29,1.166678,1.306656,0.550000
30,1.178457,1.427772,0.350000
31,1.190176,2.340953,0.233333
32,1.220153,1.795185,0.350000
33,1.198130,2.016839,0.300000
34,1.241860,1.599926,0.333333
35,1.226187,1.957110,0.250000
36,1.238088,1.962305,0.416667
37,1.232003,1.289822,0.333333
38,1.169063,1.515915,0.400000
39,1.173839,1.825251,0.366667
40,1.192527,1.347082,0.433333
41,1.167190,1.448692,0.416667
42,1.201205,1.225757,0.533333
43,1.155067,1.927499,0.450000
44,1.110300,2.206026,0.416667
45,1.105171,1.801435,0.200000
46,1.145720,1.051880,0.466667
47,1.155632,1.811285,0.250000
48,1.096927,1.194518,0.533333
49,1.088593,3.460953,0.233333
50,1.145383,1.069075,0.500000
51,1.175519,1.337830,0.350000
52,1.156374,0.876681,0.650000
53,1.144525,1.171698,0.400000
54,1.070745,0.938409,0.616667
55,1.077085,1.266402,0.400000
56,1.114236,2.326062,0.250000
57,1.004690,1.098688,0.550000
58,1.073233,0.992366,0.633333
59,1.020281,1.270277,0.533333
60,1.104294,1.111963,0.600000
61,1.097089,1.055930,0.550000
62,1.039391,3.313999,0.250000
63,1.007846,1.058527,0.650000
64,0.999835,0.991489,0.650000
65,0.997179,0.927160,0.650000
66,0.958760,0.942371,0.650000
67,0.887612,1.012384,0.583333
68,0.905010,1.205708,0.600000
69,0.907135,1.136248,0.600000
70,1.022429,1.117800,0.666667
71,0.996183,0.794599,0.700000
72,1.072672,0.799433,0.700000
73,1.041731,1.143423,0.550000
74,1.045072,0.945977,0.683333
75,0.959655,1.152286,0.633333
76,0.962949,1.005037,0.633333
77,1.013958,0.875604,0.666667
78,0.917848,0.861264,0.666667
79,0.875393,0.847685,0.683333
80,0.895439,0.971502,0.683333
81,0.877830,0.855733,0.650000
82,0.870198,0.906322,0.633333
83,0.967336,0.906497,0.650000
84,0.986497,0.987195,0.700000
85,0.929289,0.904962,0.700000
86,0.938599,0.811949,0.666667
87,0.960672,0.775793,0.716667
88,0.897664,0.776433,0.666667
89,0.907984,0.852284,0.683333
90,0.852352,0.789371,0.700000
91,0.828371,0.870477,0.683333
92,0.812102,0.840198,0.683333
93,0.866907,0.789745,0.700000
94,0.803867,0.828498,0.683333
95,0.829514,0.794899,0.700000
96,0.827412,0.815379,0.700000
97,0.812185,0.778006,0.716667
98,0.834942,0.802566,0.683333
99,0.812494,0.814997,0.683333
100,0.771604,0.857073,0.666667


In [209]:
trainNetwork("Wine",w,lr=1e-2,epochs=100)

epoch,train_loss,valid_loss,accuracy
1,-4.347342,0.445190,0.450000
2,-3.736607,0.411819,0.466667
3,-3.558654,0.379419,0.466667
4,-3.483701,0.351606,0.466667
5,-3.455778,0.337507,0.466667
6,-1.643021,0.338031,0.466667
7,-0.661304,0.337008,0.466667
8,-0.193588,0.339194,0.466667
9,0.099940,0.340972,0.466667
10,0.171600,0.333909,0.466667
11,0.307155,0.362725,0.433333
12,0.267742,0.382002,0.450000
13,0.322468,0.421524,0.466667
14,-1.346122,0.333671,0.466667
15,-0.556964,0.327216,0.450000
16,-1.860293,0.350831,0.466667
17,-0.831481,0.381810,0.466667
18,-2.017437,0.372594,0.466667
19,-0.938705,0.562532,0.433333
20,-2.118979,0.392646,0.466667
21,-0.926058,0.385888,0.433333
22,-0.318732,0.355633,0.450000
23,0.003097,1.368310,0.233333
24,-1.571072,0.515826,0.300000
25,-0.721164,0.442058,0.450000
26,-0.183517,0.701708,0.250000
27,0.074180,0.371317,0.433333
28,0.190623,0.443703,0.433333
29,-1.533793,0.397362,0.416667
30,-0.598887,0.579192,0.283333
31,-0.111302,0.417758,0.383333
32,0.121676,0.351905,0.466667
33,0.240151,0.386255,0.450000
34,0.314556,0.337116,0.450000
35,0.363286,0.438416,0.416667
36,0.327027,0.343511,0.433333
37,0.345198,0.987325,0.250000
38,0.384464,0.514152,0.333333
39,0.462066,0.364631,0.466667
40,0.458966,0.347085,0.450000
41,0.410229,0.310056,0.450000
42,0.385773,0.397065,0.450000
43,0.412908,0.337033,0.466667
44,0.320292,0.824042,0.233333
45,0.159184,0.539520,0.350000
46,-1.610795,0.438967,0.433333
47,-2.487837,0.352797,0.450000
48,-1.140981,0.468939,0.433333
49,-0.431922,0.417790,0.416667
50,-1.877508,0.701921,0.333333
51,-0.805799,0.375699,0.450000
52,-0.220637,0.334819,0.450000
53,0.062855,0.527777,0.316667
54,0.238301,0.396722,0.450000
55,0.343630,0.317985,0.450000
56,0.371981,0.908990,0.250000
57,-1.439195,0.379508,0.466667
58,-2.430872,0.405341,0.433333
59,-1.166211,0.349951,0.466667
60,-2.228200,0.340686,0.450000
61,-1.070416,0.402688,0.450000
62,-0.422184,0.357057,0.466667
63,-0.045018,0.341351,0.466667
64,-1.575618,0.360562,0.450000
65,-0.683042,0.375977,0.466667
66,-0.187051,0.340642,0.466667
67,-1.749934,0.437363,0.466667
68,-0.752392,0.356093,0.466667
69,-2.070403,0.358591,0.466667
70,-2.768805,0.385902,0.450000
71,-1.346030,0.412008,0.466667
72,-2.389117,0.381857,0.466667
73,-1.130042,0.375097,0.466667
74,-2.313262,0.404908,0.466667
75,-2.901753,0.350766,0.466667
76,-3.165075,0.386667,0.466667
77,-3.382471,0.424196,0.466667
78,-3.509089,0.411398,0.466667
79,-3.547938,0.410659,0.466667
80,-3.630021,0.362847,0.466667
81,-1.801700,0.424970,0.466667
82,-2.584016,0.452899,0.450000
83,-1.265378,0.466191,0.433333
84,-0.508414,0.401867,0.466667
85,-0.133157,0.395072,0.466667
86,-1.697159,0.419596,0.466667
87,-0.765211,0.421085,0.466667
88,-0.297720,0.424526,0.466667
89,0.025865,0.420502,0.466667
90,0.206796,0.414281,0.466667
91,0.303267,0.436358,0.466667
92,-1.460511,0.450463,0.450000
93,-4.237714,0.447838,0.466667
94,-2.113509,0.430008,0.466667
95,-0.993820,0.430656,0.466667
96,-2.180183,0.418926,0.466667
97,-1.074332,0.435356,0.466667
98,-2.227536,0.443871,0.450000
99,-2.905497,0.462517,0.450000
100,-1.384670,0.428797,0.466667


OrderedDict([('0.conv.weight',
              tensor([[[ 0.1073, -0.0872, -0.0879,  ..., -0.1043, -0.0412, -0.0535]],
              
                      [[-0.2463, -0.0052,  0.2100,  ...,  0.0963, -0.1249, -0.1503]],
              
                      [[ 0.2216,  0.1549,  0.1031,  ..., -0.0185,  0.0382, -0.1180]],
              
                      ...,
              
                      [[ 0.5980,  0.6302,  0.6335,  ...,  0.1199,  0.0938,  0.0895]],
              
                      [[-0.0889,  0.2088, -0.0529,  ..., -0.1337, -0.1374,  0.0765]],
              
                      [[ 0.5792,  0.6602,  0.2864,  ...,  0.0325, -0.3880, -0.7142]]],
                     device='cuda:0')),
             ('0.conv.bias',
              tensor([-0.3000, -0.1017, -0.3986, -0.0902, -0.3298, -0.1704,  0.4436, -0.2207,
                       0.2127,  0.3475,  0.1311, -0.2273,  0.4033, -0.0943,  0.0187, -0.1235,
                      -0.2214, -0.1711, -0.5492, -0.6197,  0.8057, -0.4565, -0