In [3]:
import torch 
import numpy as np
import pickle
from torch import nn, optim
from torch.nn import functional as F
import data_utils
from random import seed

In [4]:
np.random.seed(1234)
seed(1234)
torch.manual_seed(1234)

<torch._C.Generator at 0x7f91d318e950>

In [5]:
# dataset path
data_path = "./data/var_u.mat" # Orig IB data

# Run on GPU if possible
cuda = torch.cuda.is_available() 
device = torch.device("cuda" if cuda else "cpu")
print("Using "+ str(device))

Using cuda


In [6]:
config = {
    "loss_function" : nn.BCEWithLogitsLoss(),
    "batch_size" : 256,
    "epochs" : 200,
    "lr" : 0.0004, 
    "layer_sizes" : [12,10,7,5,4,3,2]
}

In [7]:
# Load the data
X_train, X_test, y_train, y_test = data_utils.load_data(data_path)

# Prepare data for pytorch
train_loader = data_utils.create_dataloader(X_train, y_train, len(X_train))
test_loader = data_utils.create_dataloader(X_test, y_test, len(X_test))

full_X, full_y = np.concatenate((X_train, X_test)), np.concatenate((y_train, y_test))
eval_loader = data_utils.create_dataloader(full_X, full_y, len(full_X))

In [8]:
# What do collect during training.
# Gradients of weights for all layers: something like ib_model.h3.weight.grad
# Weights for each layer: ib_model.h3.weight
# Activity: Save it in the forward pass and keep track of batch things.

# Compute L2 norm of weights
# Mean of gradients
# Std of gradients 
# Activity <- needed for MI 

In [9]:
class IBNet(nn.Module):
    def __init__(self, cfg):
        super(IBNet, self).__init__()
        torch.manual_seed(1234)
        # See https://pytorch.org/docs/stable/nn.html
        self.linears = nn.ModuleList()
        
        self.num_layers = len(cfg["layer_sizes"])
        self.inp_size = cfg["layer_sizes"][0]
        
        h_in = self.inp_size
        for h_out in cfg["layer_sizes"][1:]:
            self.linears.append(nn.Linear(h_in, h_out))
            h_in = h_out
    
        
    def forward(self, x):
        activations = [] #TODO: Could be nicer
        for idx in range(self.num_layers-1):
            x = self.linears[idx](x)
            x = torch.tanh(x)
            if idx == self.num_layers-1:
                x = self.linears[-1](x)
            if not self.training: #Internal flag in model
                activations.append(x)
         
        return x, activations

In [10]:
ib_model = IBNet(config).to(device)
optimizer = optim.Adam(ib_model.parameters(), lr=config["lr"])

In [11]:
class Trainer:
    def __init__(self, cfg, model, optimizer):
        self.opt = optimizer
        self.loss_function = cfg["loss_function"]
        self.batch_size = cfg["batch_size"]
        self.epochs = cfg["epochs"]
        self.model = model
        self.hidden_activations = [] # index 1: epoch num, index2 : layer_num
        self.weights = dict()
        self.ws_grads = dict()
        
        
    def _get_activity(self, eval_loader):
        """
        After each epoch save the activation of each hidden layer
        """
        self.model.eval()
        eval_loss = 0
        with torch.no_grad(): # Speeds up very little by turning autograd engine off.
            for i, (data, label) in enumerate(eval_loader):# No need to loop.
                data, label= data.to(device), label.to(device)
                yhat, activations = self.model(data)
                eval_loss += self.loss_function(yhat, label).item()
                
        eval_loss /= len(eval_loader.dataset)
        print('Evaluation set loss: {:.7f}'.format(eval_loss))
        return activations
    
    
    def train(self, train_loader, test_loader, eval_loader):
        for epoch in range(1, self.epochs+1):
            ### START TRAIN ###
            self.model.train()
            train_loss = 0
            for idx, (train_data, label) in enumerate(train_loader): 
                train_data, label = train_data.to(device), label.to(device)
                
                yhat, _ = self.model(train_data)
                loss = self.loss_function(yhat, label)
                
                self.opt.zero_grad()
                loss.backward()
                train_loss += loss.item()
                self.opt.step()
                
            avg_train_loss = train_loss / len(train_loader.dataset)
            print('Epoch: {} Average loss: {:.7f}'.format(epoch, avg_train_loss))
            ### STOP TRAIN ###
                
            ### START VALIDATION ###
            self.model.eval()
            valid_loss = 0
            with torch.no_grad():
                for i, (val_data, label) in enumerate(test_loader):
                    val_data, label = val_data.to(device), label.to(device)
                    yhat, activations = self.model(val_data)
                    valid_loss += self.loss_function(yhat, label).item()

            valid_loss /= len(test_loader.dataset)
            print('Validation set loss: {:.7f}'.format(valid_loss))   
            ### END VALIDATION ###
            
            ### SAVE ACTIVATION ON FULL DATA ###
            self.hidden_activations.append(self._get_activity(eval_loader))


In [12]:
tr = Trainer(config, ib_model, optimizer)

In [13]:
len(train_loader.dataset)

3277

In [14]:
tr.train(train_loader, test_loader, eval_loader)

Epoch: 1 Average loss: 0.0002129
Validation set loss: 0.0008523
Evaluation set loss: 0.0001704
Epoch: 2 Average loss: 0.0002129
Validation set loss: 0.0008523
Evaluation set loss: 0.0001703
Epoch: 3 Average loss: 0.0002129
Validation set loss: 0.0008522
Evaluation set loss: 0.0001703
Epoch: 4 Average loss: 0.0002129
Validation set loss: 0.0008521
Evaluation set loss: 0.0001703
Epoch: 5 Average loss: 0.0002128
Validation set loss: 0.0008520
Evaluation set loss: 0.0001703
Epoch: 6 Average loss: 0.0002128
Validation set loss: 0.0008519
Evaluation set loss: 0.0001703
Epoch: 7 Average loss: 0.0002128
Validation set loss: 0.0008519
Evaluation set loss: 0.0001703
Epoch: 8 Average loss: 0.0002128
Validation set loss: 0.0008518
Evaluation set loss: 0.0001702
Epoch: 9 Average loss: 0.0002128
Validation set loss: 0.0008517
Evaluation set loss: 0.0001702
Epoch: 10 Average loss: 0.0002127
Validation set loss: 0.0008516
Evaluation set loss: 0.0001702
Epoch: 11 Average loss: 0.0002127
Validation set 

Epoch: 89 Average loss: 0.0002103
Validation set loss: 0.0008424
Evaluation set loss: 0.0001683
Epoch: 90 Average loss: 0.0002103
Validation set loss: 0.0008422
Evaluation set loss: 0.0001682
Epoch: 91 Average loss: 0.0002102
Validation set loss: 0.0008420
Evaluation set loss: 0.0001682
Epoch: 92 Average loss: 0.0002102
Validation set loss: 0.0008418
Evaluation set loss: 0.0001682
Epoch: 93 Average loss: 0.0002101
Validation set loss: 0.0008416
Evaluation set loss: 0.0001681
Epoch: 94 Average loss: 0.0002101
Validation set loss: 0.0008414
Evaluation set loss: 0.0001681
Epoch: 95 Average loss: 0.0002100
Validation set loss: 0.0008412
Evaluation set loss: 0.0001680
Epoch: 96 Average loss: 0.0002100
Validation set loss: 0.0008410
Evaluation set loss: 0.0001680
Epoch: 97 Average loss: 0.0002099
Validation set loss: 0.0008407
Evaluation set loss: 0.0001679
Epoch: 98 Average loss: 0.0002099
Validation set loss: 0.0008405
Evaluation set loss: 0.0001679
Epoch: 99 Average loss: 0.0002098
Valida

KeyboardInterrupt: 

In [44]:
len(tr.hidden_activations)

1

In [48]:
ib_model = IBNet().to(device)
optimizer = optim.Adam(ib_model.parameters(), lr=0.0004)
loss_function = nn.NLLLoss()

epochs = 1

for epoch in range(1, epochs + 1):
        ### Start train pass ###
        ib_model.train()
        train_loss = 0
        for idx, (data, label) in enumerate(train_loader):
            data = data.to(device)
            label = label.long().to(device)

            yhat = ib_model(data)
            loss = loss_function(torch.log(yhat), label) #NLLLoss needs log probas.
            
            optimizer.zero_grad()
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            
            # Get gradients for each layer
            print(ib_model.h3.weight) # for 3'rd hidden layer
            print(ib_model.h3.weight.detach().numpy().shape)
        avg_train_loss = train_loss / len(train_loader.dataset)
        print('Epoch: {} Average loss: {:.7f}'.format(epoch, avg_train_loss))
        ### End train pass ###

        ### Start evaluate on validation set ###
        ib_model.eval()
        valid_loss = 0
        with torch.no_grad():
            for i, (data, label) in enumerate(test_loader):
                data = data.to(device)
                label = label.long().to(device)
                
                yhat = ib_model(data)
                valid_loss += loss_function(yhat, label).item()
        
        valid_loss /= len(test_loader.dataset)
        print('Validation set loss: {:.7f}'.format(valid_loss))

Parameter containing:
tensor([[-0.2035,  0.2650, -0.3538,  0.2026, -0.3642, -0.0774,  0.1970],
        [-0.2346,  0.3401,  0.2999,  0.3037,  0.2033, -0.0376,  0.2969],
        [ 0.0186, -0.2844, -0.2035,  0.3039, -0.3554, -0.0389,  0.1780],
        [-0.0967, -0.0714, -0.1287,  0.0062,  0.0489, -0.2657,  0.3678],
        [-0.0112,  0.0172,  0.1437, -0.2345,  0.0525,  0.2868, -0.0176]],
       requires_grad=True)
(5, 7)
Parameter containing:
tensor([[-0.2039,  0.2654, -0.3542,  0.2022, -0.3638, -0.0778,  0.1974],
        [-0.2342,  0.3397,  0.3003,  0.3041,  0.2029, -0.0372,  0.2965],
        [ 0.0190, -0.2848, -0.2031,  0.3043, -0.3558, -0.0385,  0.1776],
        [-0.0971, -0.0710, -0.1291,  0.0058,  0.0493, -0.2661,  0.3682],
        [-0.0116,  0.0176,  0.1433, -0.2349,  0.0529,  0.2864, -0.0172]],
       requires_grad=True)
(5, 7)
Parameter containing:
tensor([[-0.2043,  0.2658, -0.3545,  0.2019, -0.3634, -0.0782,  0.1978],
        [-0.2338,  0.3393,  0.3006,  0.3044,  0.2025, -0.0368