In [3]:
import numpy as np
import h5py
import torch
import torch.utils.data as data_utils
from torch.utils.data.dataset import random_split
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim

  from ._conv import register_converters as _register_converters


# Helper Clases / Functions

In [4]:
def Load_Data(num):
    if (num == -1): # All data
        X_all = []
        y_all = []
        for i in range(8):
            file_path = './../project_datasets/A0' + str(i+1) + 'T_slice.mat'
            data = h5py.File(file_path, 'r')
            X = np.copy(data['image'])
            y = np.copy(data['type'])
            X = X[:, 0:23, :]
            X_all.append(X)
            y = y[0,0:X.shape[0]:1]
            y_all.append(y)
        A, N, E, T = np.shape(X_all)
        X_all = np.reshape(X_all, (A*N, E, T))
        y_all = np.reshape(y_all, (-1))
        y_all = y_all - 769
        ## Remove NAN
        index_Nan = []
        for i in range(A*N):
            for j in range(E):
                if (any(np.isnan(X_all[i,j])) == True):
                    index_Nan.append(i)
        index_Nan = list(set(index_Nan))
        X_all = np.delete(X_all, index_Nan, axis=0)
        y_all = np.delete(y_all, index_Nan)
        return (X_all, y_all)
    else:
        file_path = './../project_datasets/A0' + str(num) + 'T_slice.mat'
        data = h5py.File(file_path, 'r')
        X = np.copy(data['image'])
        y = np.copy(data['type'])
        X = X[:, 0:23, :]
        y = y[0,0:X.shape[0]:1]
        y = y - 769
         ## Remove NAN
        N, E, T = np.shape(X)
        index_Nan = []
        for i in range(N):
            for j in range(E):
                if (any(np.isnan(X[i,j])) == True):
                    index_Nan.append(i)
        index_Nan = list(set(index_Nan))
        X = np.delete(X, index_Nan, axis=0)
        y = np.delete(y, index_Nan)
        return (X, y)

# Load Data

In [5]:
X, y = Load_Data(-1) # -1 to load all datas
N, E, T = np.shape(X)
print (np.shape(X))

(2280, 23, 1000)


# Make DataLoaders

In [51]:
bs_train = 200
bs_val = 100
bs_test = 100
data = data_utils.TensorDataset(torch.Tensor(X), torch.Tensor(y))
dset = {}
dataloaders = {}
dset['train'], dset['val'], dset['test'] = random_split(data, [N-bs_val-bs_test, bs_val, bs_test])
dataloaders['train'] = data_utils.DataLoader(dset['train'], batch_size=bs_train, shuffle=True, num_workers=1)
dataloaders['val'] = data_utils.DataLoader(dset['val'], batch_size=bs_val, shuffle=True, num_workers=1)
dataloaders['test'] = data_utils.DataLoader(dset['test'], batch_size=bs_test, shuffle=True, num_workers=1)

# Define Model

In [52]:
class myConv(nn.Module):
    def __init__(self, num_class):
        super(myConv, self).__init__()
        self.conv_temp = nn.Conv2d(1,40,tuple([1,25]))
        self.conv_elec = nn.Conv3d(1,40,tuple([40, 23, 1]))
        self.pool = nn.AvgPool2d(tuple([1,47]))
        self.classifier = nn.Linear(40*20, num_class)
    def forward(self, x):
        N, H, W = x.size()
        x.unsqueeze_(1)
        out_conv_temp = self.conv_temp(x)
        out_conv_temp = out_conv_temp.unsqueeze_(1)
        out_conv_elec = self.conv_elec(out_conv_temp)
        out_conv_elec = torch.squeeze(out_conv_elec) # shape: [N, 40, 976]
        out_conv_elec.unsqueeze_(1)
        out_pool = self.pool(out_conv_elec) 
        out_pool = torch.squeeze(out_pool) # shape: [N, 40, 20]
        out_pool = out_pool.view(N, -1) # shape: [N, 800]
        out = self.classifier(out_pool)
        return out
    def check_accuracy(self, dataloader):
        total_correct = 0
        total_label = 0
        for i_batch, sample_batched in enumerate(dataloader):
            X_sample, y_sample = sample_batched
            X_sample, y_sample = Variable(X_sample), Variable(y_sample)
            out = self.forward(X_sample.cuda())
            _, pred = torch.max(out, 1)
            num_correct = np.sum(pred.data.cpu().numpy() == y_sample.data.cpu().numpy())
            total_correct += num_correct
            total_label += len(pred)
        return  total_correct / total_label

In [65]:
dtype = torch.cuda.FloatTensor
num_classes = 4
num_epoches = 100
model = myConv(num_classes)
model.type(dtype)
loss_fn = nn.CrossEntropyLoss().type(dtype)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Conv Training

In [66]:
best_acc = 0.67

In [67]:
for epoch in range(num_epoches):
    for i, data in enumerate(dataloaders['train'], 0):
        X_train, y_train = data
        # Wrap them in Variable
        X_train, y_train = Variable(X_train), Variable(y_train)
        # forward + backward + optimize
        out = model(X_train.cuda())
        # print (out)
        loss = loss_fn(out, y_train.long().cuda())
        print('(%d batch) loss: %f' % (i, loss))
        # zero the parameter gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_acc = model.check_accuracy(dataloaders['train'])
    val_acc = model.check_accuracy(dataloaders['val'])
    print('(Epoch %d / %d) train_acc: %f; val_acc: %f' % (epoch+1, num_epoches, train_acc, val_acc))
    if (val_acc > best_acc):
        best_acc = val_acc
        torch.save(model, 'best_SHALLOW_CONV.pt')

(0 batch) loss: 1.681994
(1 batch) loss: 1.741327
(2 batch) loss: 1.476816
(3 batch) loss: 1.490850
(4 batch) loss: 1.450372
(5 batch) loss: 1.467518
(6 batch) loss: 1.280683
(7 batch) loss: 1.316637
(8 batch) loss: 1.334978
(9 batch) loss: 1.406701
(10 batch) loss: 1.278842
(Epoch 1 / 100) train_acc: 0.481250; val_acc: 0.420000
(0 batch) loss: 1.216806
(1 batch) loss: 1.124170
(2 batch) loss: 1.207967
(3 batch) loss: 1.136848
(4 batch) loss: 1.125348
(5 batch) loss: 1.079948
(6 batch) loss: 1.139096
(7 batch) loss: 1.120370
(8 batch) loss: 1.204879
(9 batch) loss: 1.114526
(10 batch) loss: 1.051921
(Epoch 2 / 100) train_acc: 0.573077; val_acc: 0.540000
(0 batch) loss: 1.112021
(1 batch) loss: 1.051784
(2 batch) loss: 0.960751
(3 batch) loss: 0.957773
(4 batch) loss: 0.992380
(5 batch) loss: 0.987003
(6 batch) loss: 0.966693
(7 batch) loss: 1.061503
(8 batch) loss: 0.951873
(9 batch) loss: 1.016487
(10 batch) loss: 0.934715
(Epoch 3 / 100) train_acc: 0.615385; val_acc: 0.620000
(0 batc

  "type " + obj.__name__ + ". It won't be checked "


(0 batch) loss: 0.604576
(1 batch) loss: 0.617040
(2 batch) loss: 0.462088
(3 batch) loss: 0.565822
(4 batch) loss: 0.571278
(5 batch) loss: 0.563454
(6 batch) loss: 0.589214
(7 batch) loss: 0.528381
(8 batch) loss: 0.599452
(9 batch) loss: 0.616774
(10 batch) loss: 0.546906
(Epoch 20 / 100) train_acc: 0.812019; val_acc: 0.640000
(0 batch) loss: 0.532901
(1 batch) loss: 0.586928
(2 batch) loss: 0.603047
(3 batch) loss: 0.524598
(4 batch) loss: 0.540848
(5 batch) loss: 0.670017
(6 batch) loss: 0.502278
(7 batch) loss: 0.429817
(8 batch) loss: 0.628474
(9 batch) loss: 0.547757
(10 batch) loss: 0.563726
(Epoch 21 / 100) train_acc: 0.816346; val_acc: 0.650000
(0 batch) loss: 0.473851
(1 batch) loss: 0.457771
(2 batch) loss: 0.520813
(3 batch) loss: 0.570726
(4 batch) loss: 0.617338
(5 batch) loss: 0.618141
(6 batch) loss: 0.559964
(7 batch) loss: 0.502003
(8 batch) loss: 0.582777
(9 batch) loss: 0.533281
(10 batch) loss: 0.476705
(Epoch 22 / 100) train_acc: 0.806731; val_acc: 0.590000
(0 b

(10 batch) loss: 0.483190
(Epoch 44 / 100) train_acc: 0.859615; val_acc: 0.600000
(0 batch) loss: 0.372405
(1 batch) loss: 0.344371
(2 batch) loss: 0.402747
(3 batch) loss: 0.447088
(4 batch) loss: 0.490947
(5 batch) loss: 0.308106
(6 batch) loss: 0.451739
(7 batch) loss: 0.492055
(8 batch) loss: 0.385574
(9 batch) loss: 0.437264
(10 batch) loss: 0.425404
(Epoch 45 / 100) train_acc: 0.864423; val_acc: 0.640000
(0 batch) loss: 0.382344
(1 batch) loss: 0.339638
(2 batch) loss: 0.452390
(3 batch) loss: 0.346529
(4 batch) loss: 0.355850
(5 batch) loss: 0.349666
(6 batch) loss: 0.335504
(7 batch) loss: 0.386716
(8 batch) loss: 0.433690
(9 batch) loss: 0.358216
(10 batch) loss: 0.480096
(Epoch 46 / 100) train_acc: 0.893269; val_acc: 0.620000
(0 batch) loss: 0.370888
(1 batch) loss: 0.291829
(2 batch) loss: 0.354600
(3 batch) loss: 0.314838
(4 batch) loss: 0.345580
(5 batch) loss: 0.391930
(6 batch) loss: 0.408138
(7 batch) loss: 0.413015
(8 batch) loss: 0.267310
(9 batch) loss: 0.338601
(10 

(6 batch) loss: 0.370798
(7 batch) loss: 0.314949
(8 batch) loss: 0.299147
(9 batch) loss: 0.320156
(10 batch) loss: 0.250713
(Epoch 69 / 100) train_acc: 0.897596; val_acc: 0.500000
(0 batch) loss: 0.290989
(1 batch) loss: 0.276039
(2 batch) loss: 0.362930
(3 batch) loss: 0.253117
(4 batch) loss: 0.232330
(5 batch) loss: 0.253719
(6 batch) loss: 0.280766
(7 batch) loss: 0.235467
(8 batch) loss: 0.212853
(9 batch) loss: 0.372074
(10 batch) loss: 0.206477
(Epoch 70 / 100) train_acc: 0.919712; val_acc: 0.580000
(0 batch) loss: 0.241923
(1 batch) loss: 0.255481
(2 batch) loss: 0.276816
(3 batch) loss: 0.245853
(4 batch) loss: 0.263462
(5 batch) loss: 0.320013
(6 batch) loss: 0.231150
(7 batch) loss: 0.277787
(8 batch) loss: 0.310023
(9 batch) loss: 0.279387
(10 batch) loss: 0.240825
(Epoch 71 / 100) train_acc: 0.937981; val_acc: 0.610000
(0 batch) loss: 0.193157
(1 batch) loss: 0.208396
(2 batch) loss: 0.239227
(3 batch) loss: 0.198579
(4 batch) loss: 0.199297
(5 batch) loss: 0.267776
(6 b

(2 batch) loss: 0.183424
(3 batch) loss: 0.165531
(4 batch) loss: 0.164185
(5 batch) loss: 0.117948
(6 batch) loss: 0.128476
(7 batch) loss: 0.146122
(8 batch) loss: 0.115939
(9 batch) loss: 0.088024
(10 batch) loss: 0.074454
(Epoch 94 / 100) train_acc: 0.961538; val_acc: 0.630000
(0 batch) loss: 0.201209
(1 batch) loss: 0.112855
(2 batch) loss: 0.086731
(3 batch) loss: 0.092480
(4 batch) loss: 0.112610
(5 batch) loss: 0.118108
(6 batch) loss: 0.127120
(7 batch) loss: 0.087514
(8 batch) loss: 0.097384
(9 batch) loss: 0.137061
(10 batch) loss: 0.104594
(Epoch 95 / 100) train_acc: 0.978365; val_acc: 0.550000
(0 batch) loss: 0.088828
(1 batch) loss: 0.111047
(2 batch) loss: 0.151016
(3 batch) loss: 0.083606
(4 batch) loss: 0.111520
(5 batch) loss: 0.092865
(6 batch) loss: 0.101622
(7 batch) loss: 0.126045
(8 batch) loss: 0.104132
(9 batch) loss: 0.116604
(10 batch) loss: 0.209511
(Epoch 96 / 100) train_acc: 0.970673; val_acc: 0.550000
(0 batch) loss: 0.162550
(1 batch) loss: 0.091008
(2 b

# Best Model Test

In [60]:
best_model = torch.load('best_SHALLOW_CONV.pt')
print (best_model.check_accuracy(dataloaders['test']))

0.65
