In [33]:
import os
import argparse
import logging
import time

import matplotlib.pyplot as plt
import numpy as np
import h5py
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
from torch.utils import data
import torchvision.transforms as transforms

from torchdiffeq import odeint_adjoint as odeint

# Model Parameters

In [2]:
# GPU device
device = 'cuda:0'

# Downsampling method to use
# conv: convolutional layers
# res: residual layers
downsampling_method = 'res'
num_channels = 64
num_epochs = 100

batch_time = 5
batch_size = 20
total_data = 5000
learning_rate = 1e-4

# Use CUDA when available
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
    torch.backends.cudnn.benchmark=True
else:
    dtype = torch.FloatTensor

# Building Blocks and Definition for an ODENet

In [75]:
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)

class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)
    
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut

class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value
        
class ODENet(nn.Module):
    
    def __init__(self, odefunc):
        super(ODENet, self).__init__()
        
        if downsampling_method is 'conv':
            self.downsampling_layers = [
                nn.Conv2d(1, num_channels, 3, 1),
                norm(num_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(num_channels, num_channels, 4, 2, 1),
                norm(num_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(num_channels, num_channels, 4, 2, 1),
            ]
        elif downsampling_method is 'res':
            self.downsampling_layers = [
                nn.Conv2d(1, num_channels, 3, 1),
                ResBlock(num_channels, num_channels, stride=2, downsample=conv1x1(num_channels, num_channels, 2)),
                ResBlock(num_channels, num_channels, stride=2, downsample=conv1x1(num_channels, num_channels, 2)),
            ]

        self.feature_layers = [ODEBlock(odefunc(num_channels))]
        
        self.fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), 
                          Flatten(), nn.Linear(64, 10)]
        
        self._layers = nn.Sequential(
                        *self.downsampling_layers,
                        *self.feature_layers,
                        *self.fc_layers)
        
    def forward(self, x):
        return self._layers(x)

# Define different ODE Function Blocks

In [61]:
class ODEClassificationFunc(nn.Module):

    def __init__(self, dim):
        super(ODEClassificationFunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out

# Data Loading

Loading of SAR images and their respective labels.

In [88]:
class MSTARDataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, file_name, transform=None):
        'Initialization'
        super(MSTARDataset, self).__init__()
        self.file = h5py.File(file_name, 'r')
        self.transform = transform
        self.length = self.file['images'].shape[0]

    def __len__(self):
        'Denotes the total number of samples'
        return self.length

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        image = self.file['images'][index,:,:].astype('float32')
        encoded_label = self.file['encoded_labels'][index]
        #label = self.file['labels'][index]
    
        if self.transform is not None:
            self.transform(image)
            
        return image, encoded_label
    
def complex_abs(data, dim=-1, keepdim=True):
    print(data.shape)
    assert data.size(-1) == 2
    return (data ** 2).sum(dim=-1, keepdim=True).sqrt()    
    
transform = transforms.Compose(
    [transforms.Lambda(lambda tensor: complex_abs(tensor))])

trainloader = data.DataLoader(MSTARDataset('data/mstart_train_data.hdf5', transform), batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testloader = data.DataLoader(MSTARDataset('data/mstart_test_data.hdf5', transform), batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Model Setup

In [86]:
model = ODENet(ODEClassificationFunc)
model = model.to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss().to(device)
end = time.time()

# Training

In [87]:
for itr in range(0, num_epochs): 
    model.train()
    for batch_count, batch_data in enumerate(trainloader):
        optimizer.zero_grad()
        image_batch, label_true = batch_data
        image_batch = image_batch.to(device)
        
        label_pred = model(image_batch)
        
        loss = criterion(label_pred, label_true)
        loss.backward()
        optimizer.step()
    
    total_predict = 0
    total_correct = 0
    model.eval()
    for batch_count, batch_data in enumerate(testloader):
        with torch.no_grad():
            image_batch, label_true = batch_data
            image_batch = image_batch.to(device)
            
            label_pred = model(image_batch)
            
            total_predict   += label_true.size(0)
            total_correct += (label_pred == label_true).sum().item()

    print('Iter {:04d} | Total Correct {:04d} out of {:04d}'.format(itr, total_correct, total_predict))

    end = time.time()



[[-0.00516798  0.01387577  0.02311549 ... -0.0055943   0.00893019
  -0.02152561]
 [-0.02389601 -0.0281056  -0.00316086 ...  0.04302207 -0.00482417
  -0.05653864]
 [-0.00375925 -0.03050296 -0.00886127 ...  0.00298356 -0.02950739
  -0.06167291]
 ...
 [ 0.04684278  0.04201766  0.02161201 ... -0.02964322 -0.08212417
  -0.10332746]
 [ 0.03267672  0.0154662  -0.01498783 ... -0.06160478 -0.0968824
  -0.13067466]
 [-0.02965261 -0.023783   -0.05946792 ... -0.05895231 -0.11678518
  -0.1156453 ]]
[[-0.0087558  -0.03838169 -0.04578689 ... -0.01186337 -0.02341154
  -0.01902867]
 [ 0.02535102  0.02545847  0.01492356 ... -0.01699662 -0.0145135
  -0.01648998]
 [ 0.04728644  0.03865997  0.01525966 ... -0.02407691 -0.01951841
   0.01519415]
 ...
 [ 0.02741541  0.02692742  0.04524927 ...  0.02085725 -0.0121762
  -0.03953116]
 [ 0.03326316  0.01196783  0.01596735 ...  0.0190334  -0.01358725
  -0.04534331]
 [ 0.02632676 -0.0127333  -0.00112978 ... -0.00319404 -0.01245982
  -0.01990991]]
[[ 0.0176753  -0.03

TypeError: Traceback (most recent call last):
  File "/home/fallah.5/anaconda3/envs/mstar/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/fallah.5/anaconda3/envs/mstar/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "<ipython-input-85-023c1e72d7ee>", line 22, in __getitem__
    self.transform(image)
  File "/home/fallah.5/anaconda3/envs/mstar/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 49, in __call__
    img = t(img)
  File "/home/fallah.5/anaconda3/envs/mstar/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 283, in __call__
    return self.lambd(img)
  File "<ipython-input-85-023c1e72d7ee>", line 32, in <lambda>
    [transforms.Lambda(lambda tensor: complex_abs(tensor))])
  File "<ipython-input-85-023c1e72d7ee>", line 28, in complex_abs
    assert data.size(-1) == 2
TypeError: 'int' object is not callable


[[ 0.04231434  0.07339521  0.05024559 ...  0.01066443 -0.00728734
  -0.02187674]
 [ 0.01291047  0.05053226  0.03470858 ... -0.02431109 -0.04774467
  -0.04251335]
 [-0.03652437 -0.00620858 -0.00162225 ... -0.06948606 -0.07200991
  -0.03753733]
 ...
 [ 0.01597822 -0.00198883 -0.00359587 ... -0.02968001 -0.0279764
  -0.01430927]
 [-0.02515339  0.00129177  0.00579153 ... -0.05262707 -0.09510272
  -0.07532839]
 [-0.01996377 -0.01881683 -0.01990695 ... -0.03434342 -0.08274035
  -0.07062592]]
