In [1]:
%matplotlib inline
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [2]:
# Set up the GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


cuda:0


In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

d1 = np.float32(np.random.randn(1,1,5,5,250))
d2 = np.float32(np.random.randn(1,1,5,5,255))

imgs    = [torch.tensor(d1), torch.tensor(d2)]
#print(imgs[1])

allLabels  = [torch.randint(4,(250,),dtype=torch.long), torch.randint(4,(255,),dtype=torch.long)]

#d3 = torch.randn(2,5,5,101)
#d4 = F.avg_pool3d(d3, (5,5,1), stride=None, padding=0)
#print(d4.shape)

In [4]:
def conv3x3x3(x,K):
    """3x3 convolution with padding"""
    return F.conv3d(x, K, padding=1)

def conv3x3x3T(x,K):
    """3x3 convolution transpose with padding"""
    #K = torch.transpose(K,0,1)
    return F.conv_transpose3d(x, K, padding=1)
        
        
dis = nn.CrossEntropyLoss()
def misfit(X,W,C):    
    n = W.shape
    X = X.view(-1,n[0])
    S = torch.matmul(X,W)
    return dis(S,C), S   

def getAccuracy(S,labels):
    _, predicted = torch.max(S.data, 1)
    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    return correct/total

In [5]:
class ResNet(nn.Module):

    def __init__(self, h,NG):
        super().__init__()

        # network geometry
        self.NG       = NG
        # time step
        self.h        = h
        # coarsening and TV norm
        
        
    def forward(self,x,Kresnet):
    
        nt = len(Kresnet)
        
        # time stepping
        for j in range(nt):
            
            # First case - rsent style step
            if NG[0,j] == NG[1,j]: 
                #print(torch.norm(z))
                z  = conv3x3x3(x, Kresnet[j])
                z  = F.instance_norm(z)
                z  = F.relu(z)        
                z  = conv3x3x3T(z,Kresnet[j])
                x  = x - self.h*z
            # Change number of channels/resolution    
            else:
                z  = conv3x3x3(x, Kresnet[j])
                z  = F.instance_norm(z)
                x  = F.relu(z)
        
        # compress in x-y dimensions 
        #x = F.avg_pool3d(x, (5,5,1), stride=None, padding=0)
             
        return x #torch.transpose(p,0,1)
        
   

In [6]:
# initialize net and weights
h           = 1e0

# Network geometry
NG = [1,    16,     16,    16,  
      16,    16,     16,    16, 
      0,    0,     0,    0]

NG = np.reshape(NG,(4,-1))


net   = ResNet(h,NG)

nsteps = NG.shape[1]


Kresnet = []
for i in range(nsteps):  
    Ki  = nn.Parameter(torch.Tensor(np.asscalar(NG[1,i]), np.asscalar(NG[0,i]),3,3,3))
    stdv  = 1e-3
    Ki.data.uniform_(-stdv, stdv)    
    # Move to the GPU
    Ki.data = Ki.data.to(device)
    
    #print(torch.norm(Ki))
    Kresnet.append(Ki)
    
# weights for linear classifier    
W     = nn.Parameter(torch.Tensor(np.asscalar(NG[1,-1])*25,4))
stdv  = 1e-3
W.data.uniform_(-stdv, stdv)
    
# Move to GPU
net.to(device)
d1gpu = imgs[0].to(device)
W.data = W.data.to(device)
x = net(d1gpu,Kresnet)

In [7]:
# run the network on the GPU
imgs0, labels0 = imgs[0].to(device), allLabels[0].to(device)
x = net(imgs0,Kresnet)

loss,_ = misfit(x,W,labels0)
print(loss)   


tensor(1.3865, device='cuda:0')


In [8]:
import torch.optim as optim
optimizer = optim.SGD([{'params':Kresnet},{'params': W}], lr=1e-1, momentum=0.9)

# Print every _ iterations
p_iter = 1

# Run _ epochs
n_epoch = 20

In [9]:
for epoch in range(n_epoch):  # loop over the dataset multiple times

    running_loss = 0.0
    running_accuracy = 0.0
    
    print('Epoch   Iteration   Loss(run)   Acc(run)   Acc(val)')
    print('---------------------------------------------------')
    
    for i in range(len(imgs)):
        # get the inputs
        inputs = imgs[i] 
        labels = allLabels[i]
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        x    = net(inputs,Kresnet)
        loss, Si = misfit(x,W,labels)
        loss.backward()
             
        optimizer.step()

        # print statistics
        accuracy = getAccuracy(Si,labels)
        running_loss     += loss.item()
        running_accuracy += accuracy
        if i % p_iter == (p_iter-1):    # print every p_iter mini-batches
            # compute validation accuracy
            #with torch.no_grad():
            #    #for data in testloader:
            #    dataiter = iter(testloader)
            #    inputsV, labelsV = dataiter.next()
            #    inputsV, labelsV = inputsV.to(device), labelsV.to(device)
            #    xV = net(inputsV,Kresnet)
            #    lossV, SiV = misfit(xV,W,labelsV)
            #    accuracyV  = getAccuracy(SiV,labelsV)

            accuracyV = 0
            print(' %2d      %5d        %.3f      %.3f      %.3f' %
                  (epoch + 1, i + 1, running_loss / p_iter, running_accuracy/p_iter, accuracyV))
            running_loss = 0.0
            running_accuracy = 0.0

print('Finished Training')

Epoch   Iteration   Loss(run)   Acc(run)   Acc(val)
---------------------------------------------------
  1          1        1.386      0.232      0.000
  1          2        1.393      0.243      0.000
Epoch   Iteration   Loss(run)   Acc(run)   Acc(val)
---------------------------------------------------
  2          1        1.384      0.264      0.000
  2          2        1.314      0.596      0.000
Epoch   Iteration   Loss(run)   Acc(run)   Acc(val)
---------------------------------------------------
  3          1        1.255      0.576      0.000
  3          2        1.163      0.714      0.000
Epoch   Iteration   Loss(run)   Acc(run)   Acc(val)
---------------------------------------------------
  4          1        1.058      0.736      0.000
  4          2        0.969      0.827      0.000
Epoch   Iteration   Loss(run)   Acc(run)   Acc(val)
---------------------------------------------------
  5          1        0.839      0.804      0.000
  5          2        0.784   