In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

import numpy as np
class Network(nn.Module):
    # replica of Dr Chandra's code
    def __init__(self,input_shape,hidden_shape,output_shape):
        super().__init__()
        self.topology = (input_shape,hidden_shape,output_shape)
        self.l1 = nn.Linear(input_shape,hidden_shape)
        self.l1_act = nn.Sigmoid()
        self.l2 = nn.Linear(hidden_shape,output_shape)
        self.out_act = nn.Sigmoid()
        
        self.export_w = np.zeros(input_shape*hidden_shape+hidden_shape+hidden_shape*output_shape+output_shape)
    def forward(self,x):
        return self.out_act(self.l2(self.l1_act(self.l1(x))))
    def loss(self,output,target_label):
        # after checking with cs9444 notes, the lgmcmc loss is just a l2 loss
        return nn.MSELoss(output,target_label) # by default, pytorch uses mean
    def encode(self):
        # export weights for mcmc
        self.export_w = np.concatenate((
            self.l1.weight.data.ravel().numpy(),
            self.l1.bias.data.ravel().numpy(),
            self.l2.weight.data.ravel().numpy(),
            self.l2.bias.data.ravel().numpy(),
        ))
        return self.export_w
    def decode(self,w):
        # import weights from mcmc
        l1w = self.topology[0]*self.topology[1]
        l2w = self.topology[1]*self.topology[2]
        l2w_start = l1w+self.topology[1]
        print(len(w[0:l1w]))
        print(len(w[l1w:l1w+self.topology[1]]))
        print(len(w[l2w_start:l2w_start+self.topology[2]]))
        print(len(w[l2w_start+self.topology[2]:]))

        self.l1.weight.data = torch.from_numpy( w[0:l1w].reshape((self.topology[1],self.topology[0])))
        self.l1.bias.data   = torch.from_numpy( w[l1w:l1w+self.topology[1]])
        self.l2.weight.data = torch.from_numpy( w[l2w_start:l2w_start+l2w].reshape((self.topology[2],self.topology[1])))
        self.l2.bias.data   = torch.from_numpy( w[l2w_start+l2w:])

    def evaluate_proposal(self,data):
        
        return self(data)
class MCMC:
    pass

In [6]:
n = Network(5,3,2)
d = torch.from_numpy(np.array([1,2,3,4,5],dtype = 'float32'))
print(n.evaluate_proposal(d))
print(n(d))


tensor([0.4995, 0.5238], grad_fn=<SigmoidBackward>)
tensor([0.4995, 0.5238], grad_fn=<SigmoidBackward>)


In [32]:
n = Network(5,3,2)
w = n.encode()
print(n.l1.weight.data)
print(n.l1.bias.data)
print(n.l2.weight.data)
print(n.l2.bias.data)
print("------------------------")
print(n.l1.weight.data.data_ptr())
print(hex(id(n.l1.weight.data.numpy())))
print("------------------------")

n.decode(w)
print(n.l1.weight.data)
print(n.l1.bias.data)
print(n.l2.weight.data)
print(n.l2.bias.data)
print("------------------------")
print(n.l1.weight.data.data_ptr())
print(hex(id(n.l1.weight.data.numpy())))

tensor([[-0.1113, -0.2503, -0.1069,  0.4155, -0.1043],
        [-0.2677, -0.2647, -0.2123,  0.0729, -0.3985],
        [-0.2636, -0.3600,  0.1263,  0.0928, -0.1728]])
tensor([-0.1262, -0.0842, -0.3052])
tensor([[-0.2671, -0.5293, -0.4061],
        [ 0.4941,  0.2802, -0.3054]])
tensor([ 0.1598, -0.4857])
------------------------
94041948193600
0x7f291bc2de70
------------------------
15
3
2
6
tensor([[-0.1113, -0.2503, -0.1069,  0.4155, -0.1043],
        [-0.2677, -0.2647, -0.2123,  0.0729, -0.3985],
        [-0.2636, -0.3600,  0.1263,  0.0928, -0.1728]])
tensor([-0.1262, -0.0842, -0.3052])
tensor([[-0.2671, -0.5293, -0.4061],
        [ 0.4941,  0.2802, -0.3054]])
tensor([ 0.1598, -0.4857])
------------------------
94041948220192
0x7f291bc59f90


In [18]:
type(n.l1.weight.data)

torch.Tensor

In [16]:
[func for func in dir(torch.Tensor) if callable(getattr(torch.Tensor, func))]

['__abs__',
 '__add__',
 '__and__',
 '__array__',
 '__array_wrap__',
 '__bool__',
 '__class__',
 '__complex__',
 '__contains__',
 '__deepcopy__',
 '__delattr__',
 '__delitem__',
 '__dir__',
 '__div__',
 '__eq__',
 '__float__',
 '__floordiv__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__iadd__',
 '__iand__',
 '__idiv__',
 '__ifloordiv__',
 '__ilshift__',
 '__imod__',
 '__imul__',
 '__index__',
 '__init__',
 '__init_subclass__',
 '__int__',
 '__invert__',
 '__ior__',
 '__ipow__',
 '__irshift__',
 '__isub__',
 '__iter__',
 '__itruediv__',
 '__ixor__',
 '__le__',
 '__len__',
 '__long__',
 '__lshift__',
 '__lt__',
 '__matmul__',
 '__mod__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__new__',
 '__nonzero__',
 '__or__',
 '__pos__',
 '__pow__',
 '__radd__',
 '__rdiv__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__rfloordiv__',
 '__rmul__',
 '__rpow__',
 '__rshift__',
 '__rsub__',
 '__rtruediv__',
 '__setattr__',
 '__setitem__',
