In [1]:
import os, sys
import torch
import importlib
sys.path.append('../')

### Basic Usage of torch Modules

In [2]:
# Instantiate a fully input convex neural net
from src.icnn import FICNN

net = FICNN(hidden_dims=[100,50,25])

X = torch.randn(128, 28, 28)
FX = net(X)


for p in net.parameters():
    print(p.shape, p.min())


torch.Size([100, 784]) tensor(-0.0357, grad_fn=<MinBackward1>)
torch.Size([100]) tensor(-0.0355, grad_fn=<MinBackward1>)
torch.Size([50, 784]) tensor(-0.0357, grad_fn=<MinBackward1>)
torch.Size([50, 100]) tensor(0., grad_fn=<MinBackward1>)
torch.Size([50]) tensor(-0.0339, grad_fn=<MinBackward1>)
torch.Size([25, 784]) tensor(-0.0357, grad_fn=<MinBackward1>)
torch.Size([25, 50]) tensor(0., grad_fn=<MinBackward1>)
torch.Size([25]) tensor(-0.0333, grad_fn=<MinBackward1>)
torch.Size([1, 784]) tensor(-0.0356, grad_fn=<MinBackward1>)
torch.Size([1, 25]) tensor(0., grad_fn=<MinBackward1>)
torch.Size([1]) tensor(-0.0282, grad_fn=<MinBackward1>)


In [4]:
import src.utils
importlib.reload(src.utils)
from src.utils import test_convexity, test_weights


test_weights(net)

test_convexity(net)

X = torch.randn(100, 28, 28)


Convexity never violated for 1000 tested pairs.


In [4]:
# Instantiate a partially input convex neural net
from src.icnn import PICNN

X = torch.randn([128, 30])
Y = torch.randn([128, 50])

net = PICNN(x_dim = 30, y_dim = 50)

FXY = net(X,Y)

print(FXY.shape)


torch.Size([128, 1])


### Optimizing with ICNN's

In [255]:
# Generate some data: true linear relationship
indim = 24
outdim = 1
n = 2000
X = torch.randn(n,indim, requires_grad=True)
W = torch.randn(outdim,indim)
b = torch.randn(outdim)
ε = 1e-8*torch.randn(outdim)
Y = torch.relu(torch.matmul(X,W.t()) + b + ε )   # To make target non-negative, otherwise nonline and relu nets wont model it
Y = Y.detach()

print(b)
print(X.shape, Y.shape)

tensor([-0.1313])
torch.Size([2000, 24]) torch.Size([2000, 1])


In [256]:
class PositiveWeightClipper(object):
    def __init__(self, frequency=1):
        self.frequency = frequency

    def __call__(self, module):
        tol = 1e-8
        # filter the variables to get the ones you want
        if hasattr(module, 'weight_z') and module.weight_z is not None:
            w = module.weight_z.data
            w = w.clamp(tol,None)
            module.weight_z.data = w
#         if hasattr(module, 'weight_y'):
#             print(module.weight_y.data.min())

In [297]:
# Generate some data: true linear relationship
import src.icnn
importlib.reload(src.icnn)
importlib.reload(src.utils)
from src.icnn import FICNN
from src.utils import test_convexity, test_weights
from torch.nn import Linear, ReLU


# DEBUG test:
#f = torch.nn.Sequential(Linear(indim,10),ReLU(), Linear(10,5), ReLU(), Linear(5,outdim), ReLU())
f = FICNN(input_dim=indim, hidden_dims=[10,10,5], output_dim = 1, dropout=0, nonlin='relu')
optimizer = torch.optim.Adam(f.parameters(), lr=1e-2) # Seems like faster learning rate messes up things
                                                      # probably because of crude clipping
clipper = PositiveWeightClipper()


# create dataset and dataloaders
loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X,Y), batch_size=128)

epochs = 10

for epoch in range(epochs):
    for it,(x,y) in enumerate(loader):
        net.zero_grad()
        loss = torch.norm(f(x) - y)**2/(2*y.shape[0])
        loss.backward()    
        optimizer.step()
        f.apply(clipper)
        print(it,loss.item())
        if it % 10 == 0:
            test_weights(f)
            test_convexity(f)
            pass


0 4.476366996765137
Convexity never violated for 1000 tested pairs.
1 4.300981044769287
2 3.226853609085083
3 4.03611946105957
4 3.279298782348633
5 2.8895609378814697
6 3.208811044692993
7 2.8981096744537354
8 3.266568422317505
9 2.2652294635772705
10 1.8202370405197144
Convexity never violated for 1000 tested pairs.
11 1.5617696046829224
12 1.915873646736145
13 2.138279676437378
14 1.8523263931274414
15 1.9442164897918701
0 1.728805422782898
Convexity violated! Curvature: tensor([[-4.7684e-07]])


ValueError: 

In [298]:
torch.cat([f(X),Y],1)[:10]

tensor([[5.0175, 5.2461],
        [0.0000, 0.0000],
        [3.1047, 2.3003],
        [2.1032, 0.0000],
        [0.4061, 0.0000],
        [1.2425, 0.0000],
        [4.4283, 4.8646],
        [0.2079, 0.0000],
        [0.5294, 0.0000],
        [2.0367, 0.0000]], grad_fn=<SliceBackward>)

In [299]:
# Inspect all parameters
print('{:10} {:>10} {:>10}'.format('Param Type', 'min val', 'max val'))
for m in f.modules():
    for t in ['bias', 'weight_y', 'weight_z']:
        if hasattr(m, t) and getattr(m,t) is not None:
            print('{:10} {:10.2f} {:10.2f}'.format(t,getattr(m,t).min().item(), getattr(m,t).max().item()))    
        

Param Type    min val    max val
bias            -0.12       0.27
weight_y        -0.32       0.33
bias            -0.08       0.32
weight_y        -0.33       0.32
weight_z         0.02       0.42
bias            -0.08       0.28
weight_y        -0.35       0.33
weight_z         0.07       0.40
bias             0.13       0.13
weight_y        -0.32       0.28
weight_z         0.12       0.56
