In [None]:
%matplotlib inline
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [None]:
# Pick the compute device: the first CUDA GPU when one is available, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)


In [None]:
# Preprocessing: convert each image to a tensor, then normalise every
# channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split (downloaded to ./data if missing), shuffled mini-batches of 4.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2)

# Test split, read in a fixed order in batches of 100.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Human-readable class names, indexed by integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    """Display an image tensor with matplotlib.

    The input is rescaled with ``img / 2 + 0.5`` (the inverse of the
    mean-0.5 / std-0.5 normalisation used by the notebook's transform),
    converted to a NumPy array, and its first axis is moved to the end
    before being handed to ``plt.imshow``.

    Args:
        img: Tensor with three axes; presumably (channels, height, width)
            on the CPU, as produced by ``make_grid`` -- TODO confirm.
    """
    unnormalized = img / 2 + 0.5
    pixels = unnormalized.numpy()
    # axis order (0, 1, 2) -> (1, 2, 0): the leading axis becomes the trailing one
    plt.imshow(pixels.transpose(1, 2, 0))


# Fetch one random training batch.
# The builtin next() is used because DataLoader iterators no longer expose
# a `.next()` method in current PyTorch releases.
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Show the batch as a single image grid.
imshow(torchvision.utils.make_grid(images))
# Print the class name of every image in the batch (however many it holds).
print(' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))


In [None]:
def conv3x3(x,K):
    """Convolve ``x`` with kernel ``K`` using stride 1 and zero padding 1.

    With a 3x3 kernel this keeps the spatial size of ``x`` unchanged.

    Args:
        x: Input batch laid out as (N, C_in, H, W).
        K: Convolution weight laid out as (C_out, C_in, kH, kW).

    Returns:
        The result of ``F.conv2d`` (no bias term).
    """
    return F.conv2d(input=x, weight=K, stride=1, padding=1)

def conv3x3T(x,K):
    """Apply the transposed convolution of ``conv3x3`` for the same kernel.

    ``K`` is passed to ``F.conv_transpose2d`` as-is (no axis swap), with
    stride 1 and padding 1, so a kernel that maps C_in -> C_out channels in
    ``conv3x3`` maps C_out -> C_in channels here.

    Args:
        x: Input batch whose channel count equals ``K.shape[0]``.
        K: Weight tensor, same layout as the one given to ``conv3x3``.

    Returns:
        The result of ``F.conv_transpose2d`` (no bias term).
    """
    return F.conv_transpose2d(input=x, weight=K, stride=1, padding=1)

def projectTensor(K):
    """Shift every 2-D filter of a 4-D kernel array to zero mean, in place.

    For each pair of leading indices ``(i, j)`` the mean of ``K[i, j, :, :]``
    is subtracted from that slice. The subtraction is done for all slices at
    once instead of looping over them in Python.

    Args:
        K: 4-D floating-point NumPy array; the notebook passes float32
            arrays shaped (channels, channels, 3, 3). It is modified in place.

    Returns:
        The same array object ``K``, after the in-place update.
    """
    # keepdims keeps the per-filter means broadcastable against K
    K -= K.mean(axis=(2, 3), keepdims=True)
    return K
        
def projectTorchTensor(K):
    """Replace ``K.data`` with a copy whose 2-D filters all have zero mean.

    The data is viewed as a matrix with one row per ``(i, j)`` filter, the
    row means are subtracted, and the result is viewed back to the original
    4-D shape and stored in ``K.data``.

    Args:
        K: Object exposing a 4-D ``.data`` tensor (e.g. an ``nn.Parameter``).

    Returns:
        The same object ``K`` with its ``.data`` replaced.
    """
    dims = K.data.shape
    n_filters = dims[0] * dims[1]
    filter_size = dims[2] * dims[3]

    rows = K.data.view(n_filters, filter_size)
    centred = rows - rows.mean(1, keepdim=True)
    K.data = centred.view(*dims)
    return K
        
# Cross-entropy criterion shared by every call to misfit()
dis = nn.CrossEntropyLoss()

def misfit(X,W,C):
    """Linear classifier on flattened features plus its cross-entropy loss.

    Args:
        X: Feature tensor; it is viewed as a matrix with ``W.shape[0]``
            columns, so its total size must be a multiple of that number.
        W: Weight matrix of shape (features, classes).
        C: Target class indices, one per row of the flattened ``X``.

    Returns:
        Tuple ``(loss, S)`` where ``S`` is the score matrix ``X @ W`` and
        ``loss`` is ``dis(S, C)``.
    """
    n_features = W.shape[0]
    scores = X.view(-1, n_features) @ W
    loss = dis(scores, C)
    return loss, scores

def getAccuracy(S,labels):
    """Fraction of rows of ``S`` whose highest score sits at the labelled class.

    Args:
        S: Score tensor with one row per sample and one column per class.
        labels: Tensor of target class indices, one per row of ``S``.

    Returns:
        Number of correct predictions divided by ``labels.size(0)``,
        as a Python float.
    """
    # element [1] of torch.max(..., dim) holds the indices of the row maxima
    predicted = torch.max(S.data, 1)[1]
    hits = (predicted == labels).sum().item()
    return hits / labels.size(0)

Test the projection

In [None]:
outChannels = 16
inChannels = 3

# Random float32 kernel sized from the channel counts above (previously the
# literal 16 and 3 were repeated here), with every 3x3 filter shifted to
# zero mean by projectTensor.
D = np.random.randn(outChannels, inChannels, 3, 3)
D = np.float32(D)
D = projectTensor(D)

# The actual check: after the projection each filter mean must vanish
# (up to float32 round-off).
assert np.allclose(D.mean(axis=(2, 3)), 0.0, atol=1e-5), "projection left a non-zero filter mean"

# Wrap the projected kernel as a trainable parameter.
D = torch.from_numpy(D)
K = nn.Parameter(D)

In [None]:
class ResNet(nn.Module):
    """Residual-style network whose convolution kernels are passed in at call time.

    The module owns no trainable weights; ``forward`` receives the list of
    kernels and performs one step per kernel, steered by the geometry table.

    Args:
        h: Step size multiplying the residual update ``x - h*z``.
        NG: 2-D integer table indexed as ``NG[row, step]`` (array or nested
            list). Rows 0 and 1 are compared to pick the kind of step; the
            notebook's initialisation cell uses them as the input and output
            channel counts of each kernel. A 1 in row 2 requests 2x2 average
            pooling after a channel-changing step.
    """

    def __init__(self, h,NG):
        super().__init__()

        # network geometry; stored as an array so NG[row, step] indexing
        # also works when a nested list is supplied
        self.NG       = np.asarray(NG)
        # time step
        self.h        = h
        # pooling layers; `coarsen` is not used by forward() but is kept so
        # existing code that accesses the attribute keeps working
        self.coarsen  = nn.AvgPool2d(32)
        self.coarsen2 = nn.AvgPool2d(2)

    def forward(self,x,Kresnet):
        """Propagate ``x`` through ``len(Kresnet)`` steps.

        Args:
            x: Input batch, (N, C, H, W).
            Kresnet: Sequence of 4-D kernels, one per step.

        Returns:
            The tensor produced by the last step.
        """
        # Use the geometry this module was built with, not a notebook global
        # that merely happens to share the name.
        NG = self.NG

        # time stepping
        for j, Kj in enumerate(Kresnet):
            # both kinds of step start with conv -> instance norm -> ReLU
            z = conv3x3(x, Kj)
            z = F.instance_norm(z)
            z = F.relu(z)

            if NG[0,j] == NG[1,j]:
                # residual step: map back with the transposed kernel and update x
                z = conv3x3T(z, Kj)
                x = x - self.h*z
            else:
                # channel-changing step, optionally followed by 2x2 pooling
                x = z
                if NG[2,j] == 1:
                    x = self.coarsen2(x)

        return x
        
   

In [None]:
# initialize net and weights
h           = 1e0

# Network geometry, one column per step:
#   row 0 - input channels of the step's kernel
#   row 1 - output channels of the step's kernel
#   row 2 - 1 if a channel-changing step is followed by 2x2 average pooling
NG = [3,    64,    64,    64,  64,    256,   256,
      64,   64,    64,    64,  256,   256,   256,
      1,     0,     0,     0,   1,      0,     0]

NG = np.reshape(NG,(3,-1))
net   = ResNet(h,NG)

nsteps = NG.shape[1]

# scale of the random initial weights
stdv  = 1e-3

Kresnet = []
for i in range(nsteps):
    # int() converts the NumPy integers to plain Python ints; it replaces
    # np.asscalar, which has been removed from NumPy.
    n_in, n_out = int(NG[0,i]), int(NG[1,i])

    if n_in == n_out:
        # residual step: small Gaussian kernel with every 3x3 filter projected to zero mean
        D   = np.random.randn(n_in, n_out,3,3)*1e-3
        D   = np.float32(D)
        D   = projectTensor(D)
        D   = torch.from_numpy(D)
        Ki  = nn.Parameter(D)
    else:
        # channel-changing step: small uniform kernel laid out as (out, in, 3, 3)
        Ki = nn.Parameter(torch.empty(n_out, n_in, 3, 3))
        Ki.data.uniform_(-stdv, stdv)

    # Move to the GPU
    Ki.data = Ki.data.to(device)

    Kresnet.append(Ki)

# weights for linear classifier: 256 output channels of the last step on an
# 8x8 grid (presumably 32x32 inputs pooled twice -- TODO confirm), 10 classes
W     = nn.Parameter(torch.empty(256*8*8,10))
W.data.uniform_(-stdv, stdv)

# Move to GPU
net.to(device)
W.data = W.data.to(device)

In [None]:
# Squared norm of the second kernel in the list (displayed as the cell output)
torch.norm(Kresnet[1]).pow(2)

In [None]:
# run the network on the GPU
images, labels = images.to(device), labels.to(device)
x = net(images,Kresnet)

# Random targets in [0, 10), one per image in the batch. The count is taken
# from the batch itself rather than hard-coded, so this cell keeps working
# if the loader's batch size is changed.
C = torch.randint(0,10,(images.size(0),),dtype=torch.long)
C = C.to(device)

loss,_ = misfit(x,W,C)
print(loss)


In [None]:
# Bring the network output back to the CPU and reinterpret each sample as
# one 128x128 image (16*8 = 128) for plotting.
x = x.cpu()
xnp = x.data.numpy().reshape((4, 16*8, 16*8))

# Show the second sample of the batch with a colour scale.
im = plt.imshow(xnp[1])
plt.colorbar(im)

In [None]:
import torch.optim as optim

# SGD with momentum over two parameter groups: the convolution kernels
# and the linear classifier weights.
optimizer = optim.SGD(
    [
        {'params': Kresnet},
        {'params': W},
    ],
    lr=1e-4,
    momentum=0.9,
)

# Number of mini-batches between two progress reports
p_iter = 500

# Number of passes over the training set
n_epoch = 10

In [None]:
# Validation batch: the first batch of the test loader. The loader is built
# with shuffle=False, so this batch is the same on every read; it is fetched
# once here instead of re-creating the loader iterator at every report.
# next() is used because DataLoader iterators no longer expose `.next()`.
inputsV, labelsV = next(iter(testloader))
inputsV, labelsV = inputsV.to(device), labelsV.to(device)

for epoch in range(n_epoch):  # loop over the dataset multiple times

    running_loss = 0.0
    running_accuracy = 0.0

    print('Epoch   Iteration   Loss(run)   Acc(run)   Acc(val)')
    print('---------------------------------------------------')

    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        x    = net(inputs,Kresnet)
        loss, Si = misfit(x,W,labels)
        loss.backward()

        optimizer.step()

        # accumulate running statistics over the current reporting window
        accuracy = getAccuracy(Si,labels)
        running_loss     += loss.item()
        running_accuracy += accuracy
        if i % p_iter == (p_iter-1):    # print every p_iter mini-batches
            # validation accuracy on the fixed batch, without tracking gradients
            with torch.no_grad():
                xV = net(inputsV,Kresnet)
                lossV, SiV = misfit(xV,W,labelsV)
                accuracyV  = getAccuracy(SiV,labelsV)

            print(' %2d      %5d        %.3f      %.3f      %.3f' %
                  (epoch + 1, i + 1, running_loss / p_iter, running_accuracy/p_iter, accuracyV))
            running_loss = 0.0
            running_accuracy = 0.0

print('Finished Training')