In [2]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import time
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import quantile_transform, MinMaxScaler
from sklearn.metrics import confusion_matrix, roc_curve, auc, accuracy_score

from scipy.special import expit

## NN architecture

In [55]:
class InverseNN(nn.Module):
    """
    Inverse NN, i.e. from frequencies to parameters.

    Each frequency band (N_KPOINTS values) goes through its own small MLP
    ("block"). The N_BLOCKS block outputs are concatenated and passed through
    a fully connected (FC) head that predicts D_OUT parameters.
    """

    def __init__(self,
                 N_BLOCKS: int,
                 N_KPOINTS: int,
                 D_HIDDEN_BK: list,
                 D_HIDDEN_FC: list,
                 D_OUT: int,
                 P_DROPOUT=None):
        """
        @param N_BLOCKS: number of input blocks, i.e., # of bands
        @param N_KPOINTS: number of K points in each band
        @param D_HIDDEN_BK: list with the dimensions of each hidden layer in a block (BK)
        @param D_HIDDEN_FC: list with the dimensions of each hidden layer in the FC
        @param D_OUT: dimension of the output layer in the FC, i.e., # of parameters
        @param P_DROPOUT: dropout probability (float), or None to disable dropout
        """
        super().__init__()
        self.N_BLOCKS    = N_BLOCKS
        self.N_KPOINTS   = N_KPOINTS
        # Copy the lists so later mutation by the caller cannot desync the
        # stored config from the layers actually built below.
        self.D_HIDDEN_BK = list(D_HIDDEN_BK)
        self.D_HIDDEN_FC = list(D_HIDDEN_FC)
        self.D_OUT       = D_OUT
        self.P_DROPOUT   = P_DROPOUT

        # Dropout layers are only inserted when a probability is given.
        self.APPLY_DROPOUT = self.P_DROPOUT is not None

        # Layers must be created here, not inside forward(). Only then does
        # nn.Module register their parameters: they appear in .parameters(),
        # are updated by the optimiser and move with .to(device).
        # One independent block per band.
        self.blocks = nn.ModuleList(
            [self.build_LinearBlock() for _ in range(self.N_BLOCKS)]
        )
        self.fc = self.build_FC()

    def _hidden_layers(self, in_features: int, hidden_dims) -> tuple:
        """Linear -> ReLU (-> Dropout) per width in hidden_dims; returns (layers, out_features)."""
        layers = []
        for d_h in hidden_dims:
            layers.append(nn.Linear(in_features, d_h))
            layers.append(nn.ReLU())
            if self.APPLY_DROPOUT:
                layers.append(nn.Dropout(self.P_DROPOUT))
            in_features = d_h
        return layers, in_features

    def _block_out_dim(self) -> int:
        """Width of one block's output (N_KPOINTS when the block has no hidden layers)."""
        return self.D_HIDDEN_BK[-1] if self.D_HIDDEN_BK else self.N_KPOINTS

    def build_LinearBlock(self):
        """ Builds a single input block (there is one per band): N_KPOINTS -> D_HIDDEN_BK[-1]. """
        layers, _ = self._hidden_layers(self.N_KPOINTS, self.D_HIDDEN_BK)
        return nn.Sequential(*layers)

    def build_FC(self):
        """ Builds the Fully Connected head: concatenated block outputs -> D_OUT. """
        layers, in_ = self._hidden_layers(self.N_BLOCKS * self._block_out_dim(),
                                          self.D_HIDDEN_FC)
        # Final projection has no activation (regression output).
        layers.append(nn.Linear(in_, self.D_OUT))
        return nn.Sequential(*layers)

    def forward_InverseSplitData(self, DATA):
        """
        Apply each band's block to its slice of the input and concatenate the results.

        @param DATA: input batch. Assumed layout (TODO confirm against the data
                     pipeline): (BATCH_SIZE, N_BLOCKS * N_KPOINTS), with band i in
                     columns [i*N_KPOINTS, (i+1)*N_KPOINTS). A
                     (BATCH_SIZE, N_BLOCKS, N_KPOINTS) tensor is flattened to that layout.
        @return: tensor of shape (BATCH_SIZE, N_BLOCKS * block output width)
        @raises ValueError: if a sample does not contain N_BLOCKS * N_KPOINTS values
        """
        flat = DATA.flatten(start_dim=1)
        expected = self.N_BLOCKS * self.N_KPOINTS
        if flat.size(1) != expected:
            raise ValueError(
                f"expected {expected} values per sample (N_BLOCKS * N_KPOINTS), "
                f"got {flat.size(1)}"
            )

        # torch.split takes the chunk *size*: splitting by N_KPOINTS yields N_BLOCKS bands.
        fbands = torch.split(flat, self.N_KPOINTS, dim=1)
        outputs = [block(fband) for block, fband in zip(self.blocks, fbands)]
        return torch.cat(outputs, dim=1)

    def forward(self, DATA):
        """ All the NN: Blocks + FC. Returns a (BATCH_SIZE, D_OUT) tensor. """
        x = self.forward_InverseSplitData(DATA)
        return self.fc(x)

- **TODO:** Convert this notebook to a `.py` module.