In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import skorch
import torchvision.datasets as dset
import torchvision.transforms as T

from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_validate

import numpy as np

In [13]:
USE_GPU = True

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

using device: cpu


In [14]:
#load the data
x = np.load("X.npy")
y = np.load("y.npy")

x_full = torch.from_numpy(x)
y_full = torch.from_numpy(y)
x_full = x_full.to(device=device, dtype=dtype)
y_full = y_full.to(device=device, dtype=torch.long)

In [25]:
#split the data
x = x_full[30:,:,:,:]
y = y_full[30:]

print(x.shape)
print(y.shape)

x_Test = x_full[0:29,:,:,:]
y_Test = y_full[0:29]

print(x_Test.shape)
print(y_Test.shape)

torch.Size([17, 512, 512, 62])
torch.Size([17, 1])
torch.Size([29, 512, 512, 62])
torch.Size([29, 1])


In [16]:
def flatten(x):
    N = x.shape[0] # read in N, C, H, W
    return x.view(N, -1)  # "flatten" the C * H * W values into a single vector per image

class Flatten(nn.Module):
    def forward(self, x):
        return flatten(x)


In [17]:
channel_1 = 32
channel_2 = 64
channel_3 = 32


learning_rate = 2.5e-3
model = None
optimizer = None

in_channel = 62
num_classes = 3

model = nn.Sequential(
    nn.Conv2d(in_channel, channel_1, 5, padding=2),
    nn.BatchNorm2d(channel_1),
    nn.ReLU(),
    nn.Conv2d(channel_1, channel_2, 3, padding=1),
    nn.BatchNorm2d(channel_2),
    nn.ReLU(),
    nn.MaxPool2d(2),
    Flatten(),
    nn.Linear((32*512*512)/2, 100),
    nn.Linear(100, num_classes)
)
model = model.to(device=device)

from skorch import NeuralNetClassifier

net = NeuralNetClassifier(
    module=model,
    criterion = nn.CrossEntropyLoss,
    optimizer=optim.Adam,
#     optimizer_momentum = 0.9,
    train_split=None,
    max_epochs=5,
    lr= learning_rate,
    warm_start = True,
    device = device
)

In [19]:
from sklearn.model_selection import KFold
from sklearn import metrics

# for epoch in range(5):
kf = KFold(n_splits=3, shuffle = True)
accuracies=[]
for train_index, test_index in kf.split(x):
    accuracies=[]
    (N, C1, C2, S) = x.shape # 17, 62, 512, 512
    x = x.reshape((17, 62, 512, 512))
    y = y.reshape((y.shape[0]))
    print(x.shape)
    print(y.shape)
    print(train_index)
    print(test_index)
    xk_train, xk_test = x[train_index], x[test_index]
    yk_train, yk_test = y[train_index], y[test_index]
    net.fit(xk_train,yk_train)
    y_pred = net.predict(xk_test)
    acc = metrics.accuracy_score(yk_test, y_pred)
    accuracies.append(acc)
    print('FinalAccuracy %.4f' % (np.mean(accuracies)))

torch.Size([17, 62, 512, 512])
torch.Size([17])
[ 0  1  2  4  5  8 10 11 13 14 15]
[ 3  6  7  9 12 16]
      7     1968.7339  122.4494
      8     1689.8179  126.3974
      9     1616.1758  125.6589
     10      767.3795  125.9576
     11      600.9477  133.7984
FinalAccuracy 0.6667
torch.Size([17, 62, 512, 512])
torch.Size([17])
[ 0  3  4  6  7  9 12 13 14 15 16]
[ 1  2  5  8 10 11]
     12      379.1372  129.4863
     13      496.5664  138.6512
     14      390.1367  156.9161
     15      118.3768  142.3620
     16      588.9722  138.8822
FinalAccuracy 0.5000
torch.Size([17, 62, 512, 512])
torch.Size([17])
[ 1  2  3  5  6  7  8  9 10 11 12 16]
[ 0  4 13 14 15]
     17     1112.1178  144.9799
     18      423.8275  145.4350
     19      541.4341  144.9793
     20      494.2856  150.5864
     21      443.8631  158.3380
FinalAccuracy 0.2000


In [26]:
(N, C1, C2, S) = x_Test.shape
x_Test = x_Test.reshape((N, S, C1, C2))
y_pred_test = net.predict(x_Test)
acc = metrics.accuracy_score(y_Test, y_pred_test)
print('TestAccuracy %.4f' % (acc))

TestAccuracy 0.2414
