In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from numpy.random import randint
import math
from torch.nn import functional as F
import numpy as np
import time
import pdb

import utils
import models

In [2]:
class opt:
    """Run configuration for training the source-domain tomogram classifier.

    Plain class attributes serve as a lightweight namespace (no argparse); the
    cells below read and occasionally write these attributes directly.
    """
    # Root directory of the HDF5 dataset files. Set the TOMO_DATA_DIR environment
    # variable to point elsewhere instead of editing a hard-coded absolute path;
    # when it is unset the original '/datasets/tomodata' location is used.
    data_dir = os.environ.get('TOMO_DATA_DIR', '/datasets/tomodata')
    train_path = os.path.join(data_dir, 'source_train.h5')
    test_path = os.path.join(data_dir, 'source_test.h5')
    numclass = 4            # number of classes (the loader also reports 4)
    fet_size = 400          # NOTE(review): not referenced in the cells shown; presumably used by models/utils
    nepoch = 100            # number of training epochs
    cuda = True             # move the model and tensors to the GPU
    manualSeed = 9182       # RNG seed; set to None to draw a random one
    batch_size = 32         # training mini-batch size
    tomo_dim = (32,32,32)   # spatial size of each input volume (assumed D, H, W — confirm in utils)
    clfr_lr = 0.0001        # Adam learning rate for the classifier

In [3]:
# Seed every RNG the pipeline may touch so runs are repeatable. numpy is seeded
# too because the notebook imports numpy.random and utils.DATA_LOADER may draw
# batches with it (TODO confirm against utils).
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

# benchmark=True lets cuDNN autotune convolution algorithms for speed; the chosen
# algorithms can differ between runs, so results are not bit-exact even with the
# seeds above.
cudnn.benchmark = True

# Configuration lives in the `opt` class (there is no command-line flag), so point
# the user at the attribute to change.
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably set opt.cuda = True")

Random Seed:  9182


In [4]:
# Build the dataset loader from the run configuration, then report its sizes.
data = utils.DATA_LOADER(opt)
for caption, attr in (("# of training samples: ", "ntrain"),
                      ("# of test samples: ", "ntest"),
                      ("# of class: ", "numclass")):
    print(caption, getattr(data, attr))

# of training samples:  900
# of test samples:  100
# of class:  4


In [5]:
# Instantiate the source-domain classifier and print its layer summary.
source_clsfr = models.src_CLFR(opt)
print("{}".format(source_clsfr))

src_CLFR(
  (conv1): Sequential(
    (0): Conv3d(1, 8, kernel_size=(5, 5, 5), stride=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool3d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv3d(8, 16, kernel_size=(5, 5, 5), stride=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv3d(16, 32, kernel_size=(4, 4, 4), stride=(1, 1, 1))
    (1): ReLU(inplace=True)
  )
  (conv4): Sequential(
    (0): Conv3d(32, 64, kernel_size=(4, 4, 4), stride=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv5): Sequential(
    (0): Conv3d(64, 128, kernel_size=(2, 2, 2), stride=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (final): Sequential(
    (0): Linear(in_features=128, 

In [6]:
# Preallocated, uninitialised float32 CPU buffers; sample() copies each
# mini-batch into them in place.
# NOTE(review): input_feature has no channel axis although conv1 expects 1 input
# channel — presumably src_CLFR.forward reshapes; confirm in models.py.
input_feature = torch.empty(opt.batch_size, opt.tomo_dim[0], opt.tomo_dim[1], opt.tomo_dim[2],
                            dtype=torch.float32, device='cpu')
input_label = torch.empty(opt.batch_size, dtype=torch.float32, device='cpu')

# Multi-class cross-entropy over the classifier's logits.
cls_criterion = nn.CrossEntropyLoss()

In [7]:
# Move the model, the loss and the preallocated input buffers to the GPU.
# nn.Module.cuda() moves parameters in place, but Tensor.cuda() returns a NEW
# tensor, so the buffers must be reassigned. The original discarded the result,
# leaving them on the CPU. sample() keeps writing into them with copy_, which
# handles the host-to-device transfer.
if opt.cuda:
    source_clsfr.cuda()
    cls_criterion.cuda()
    input_feature = input_feature.cuda()
    input_label = input_label.cuda()

In [8]:
def sample():
    """Draw the next training mini-batch and copy it into the shared input buffers.

    Reads the module-level ``data`` loader and ``opt`` config. Writes in place into
    ``input_feature`` / ``input_label`` so the buffers keep their identity and
    device between calls. Returns None.
    """
    features, labels = data.next_batch(opt.batch_size)
    for buffer, values in ((input_feature, features), (input_label, labels)):
        buffer.copy_(values)

In [9]:
# Adam over every classifier parameter, using the configured learning rate.
optim_clfr = optim.Adam(params=source_clsfr.parameters(), lr=opt.clfr_lr)

In [11]:
# Train the source-domain classifier and report held-out accuracy after every epoch.
# Tensors go to the GPU only when opt.cuda is set. The original called .cuda()
# unconditionally, which fails on CPU-only machines even with opt.cuda = False.
device = torch.device('cuda' if opt.cuda else 'cpu')
# The test set does not change between epochs, so move it to the device once.
test_inputs = data.test_subtom.to(device)
test_targets = data.test_label.to(device)

for epoch in range(opt.nepoch):
    # ceil(ntrain / batch_size) mini-batches per epoch, each drawn via sample().
    for _ in range(0, data.ntrain, opt.batch_size):
        sample()
        source_clsfr.zero_grad()
        batch_inputs = input_feature.to(device)
        # CrossEntropyLoss needs integer class indices; the label buffer is float.
        batch_targets = input_label.long().to(device)

        logits = source_clsfr(batch_inputs)
        loss = cls_criterion(logits, batch_targets)
        loss.backward()
        optim_clfr.step()

    # Evaluate without recording an autograd graph, instead of toggling
    # requires_grad on every parameter each iteration.
    source_clsfr.eval()
    with torch.no_grad():
        test_logits = source_clsfr(test_inputs)
        _, predicted = torch.max(test_logits, 1)
        n_correct = (predicted == test_targets).sum().item()
    # float() guarantees true division regardless of Python 2/3 semantics.
    acc = float(n_correct) / data.ntest

    # NOTE: clfr_loss is the loss of the epoch's last mini-batch, not an epoch mean.
    print('[%d/%d] clfr_loss: %.4f | test_acc: %.4f' % (epoch, opt.nepoch, loss.item(), acc))
    source_clsfr.train()

[0/100] clfr_loss: 1.3680 | test_acc: 0.3900
[1/100] clfr_loss: 1.0357 | test_acc: 0.4600
[2/100] clfr_loss: 0.8219 | test_acc: 0.6300
[3/100] clfr_loss: 0.7664 | test_acc: 0.8400
[4/100] clfr_loss: 0.5138 | test_acc: 0.8100
[5/100] clfr_loss: 0.3538 | test_acc: 0.8100
[6/100] clfr_loss: 0.2404 | test_acc: 0.8700
[7/100] clfr_loss: 0.1307 | test_acc: 0.8500
[8/100] clfr_loss: 0.1724 | test_acc: 0.9200
[9/100] clfr_loss: 0.1426 | test_acc: 0.9100
[10/100] clfr_loss: 0.0219 | test_acc: 0.9100
[11/100] clfr_loss: 0.1987 | test_acc: 0.9100
[12/100] clfr_loss: 0.0825 | test_acc: 0.9300
[13/100] clfr_loss: 0.0216 | test_acc: 0.9000
[14/100] clfr_loss: 0.0388 | test_acc: 0.9100
[15/100] clfr_loss: 0.0186 | test_acc: 0.9100
[16/100] clfr_loss: 0.2554 | test_acc: 0.9400
[17/100] clfr_loss: 0.1323 | test_acc: 0.8900
[18/100] clfr_loss: 0.0482 | test_acc: 0.9200
[19/100] clfr_loss: 0.0077 | test_acc: 0.9300
[20/100] clfr_loss: 0.0382 | test_acc: 0.9100
[21/100] clfr_loss: 0.0023 | test_acc: 0.930