# In [None]:
from dash.dependencies import Input, Output, State
from helper_functions import *
import dash
import dash_core_components as dcc
import dash_html_components as html

# Global variables: the network topology most recently drawn by update_graph,
# read back by generate_code. They hold the values below until the first
# redraw stores new ones.
# NOTE(review): module-level state is shared by every browser session served by
# this process, so concurrent users overwrite each other's topology.
g_input_nodes = 5
g_hidden_nodes = 1  # a single int here; update_graph later stores a list of sizes
g_output_nodes = 3

# Slider marks, {value: label}. The input marks start at 5 to match the
# 'input_nodes' slider's min=5; a mark at 0 would lie outside the slider range.
input_layer_dict = {i: i for i in range(5, 55, 5)}

hidden_layer_dict = {i: i for i in range(1, 11)}

# NOTE(review): not referenced by the layout here (the output layer size is a
# dcc.Input, not a slider); kept in case other code relies on it.
output_layer_dict = {i: i for i in range(2, 201, 10)}


## for box layout
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
# Don't raise when a callback references an id Dash cannot validate against
# the initial layout.
app.config['suppress_callback_exceptions'] = True

# Choices offered by the three dropdowns; generate_code turns the selected
# names into PyTorch code.
activation_functions = ['Sigmoid', 'Softmax', 'ReLu', 'Tanh']
loss_functions = ['Multiclass Cross-Entropy', 'Binary Cross-Entropy', 'Maximum Likelihood', 'Log loss']
optimizers = ['Gradient Descent', 'SGD', 'RMSProp', 'Momentum', 'NAG', 'Adagrad', 'Adadelta', 'Adam']

#########################################################################################################################

###                                                     LAYOUT                                                         ## 

#########################################################################################################################


app.layout = html.Div(children=[
    html.Div([html.H1(children='Torchify')], className="row"),

    html.Div([
        # Left column: controls for the network topology and training setup.
        html.Div([
            dcc.Markdown('#### **I. Layers**'),

            dcc.Markdown('###### **Input Layer**'),
            dcc.Markdown('**Input Nodes**'),
            # The initial value must lie on the min/step grid (5, 10, ..., 50);
            # an off-grid default such as 12 can never be re-selected by the
            # user and presumably gets snapped for display by the slider
            # widget while the callbacks still receive the raw value.
            dcc.Slider(id='input_nodes',
                       min=5,
                       max=50,
                       step=5,
                       value=10,
                       marks=input_layer_dict),

            html.Br(),
            html.Br(),
            html.Br(),

            dcc.Markdown('###### **Hidden Layer(s)**'),

            # NOTE(review): no callback visible here reads this slider; the
            # hidden-layer count is implied by the comma-separated
            # 'hidden_nodes' text below.
            dcc.Slider(id='hidden_layers',
                       min=1,
                       max=10,
                       step=1,
                       value=1,
                       marks=hidden_layer_dict),

            html.Br(),
            html.Br(),
            dcc.Markdown('**Hidden Nodes Per Each Layer**'),
            html.P(''' Define the number of hidden nodes per each layer 
                                        in a form of a list. 
                                        E.g. if you have 3 hidden layers: 30, 20, 8
                                        '''),

            # Free text that update_graph parses into a list of layer sizes.
            dcc.Input(id='hidden_nodes', value='1', type='text'),

            html.Br(),
            html.Br(),

            dcc.Markdown('###### **Output Layer**'),
            dcc.Markdown('**Output Nodes**'),
            dcc.Input(id='output_nodes', value='3', type='number'),

            html.Br(),
            html.Br(),
            html.Br(),

            dcc.Markdown('#### **II. Activation Function**'),

            dcc.Dropdown(id='activations_dropdown',
                         options=[{'label': i, 'value': i} for i in activation_functions],
                         value=activation_functions[0]),

            html.Br(),
            html.Br(),

            dcc.Markdown('#### **III. Loss Function**'),
            dcc.Dropdown(id='loss_dropdown',
                         options=[{'label': i, 'value': i} for i in loss_functions],
                         value=loss_functions[0]),
            html.Br(),
            html.Br(),

            dcc.Markdown('#### **IV. Optimizer**'),
            dcc.Dropdown(id='optimizer_dropdown',
                         options=[{'label': i, 'value': i} for i in optimizers],
                         value=optimizers[0]),

            html.Br(),
            html.Br(),

            dcc.Markdown('#### **V. Code**'),
            html.Button('Generate code', id='code_generator'),

            html.Br(),
            html.Br(),
        ], className="three columns"),

        # Right column: the network diagram produced by update_graph.
        # NOTE(review): unitless string sizes like "800" are presumably not
        # valid CSS lengths here; confirm whether "800px" was intended.
        html.Div([dcc.Graph(id='neural_nets')],
                 style={"height": "800", "width": "800"},
                 className="nine columns"),
    ], className="row"),

    html.Br(),
    html.Br(),
    # Markdown code block whose children generate_code replaces.
    dcc.Markdown(
        '```\n\nGenerated code appears here.\n\n```',
        id='the_code'),
])

#########################################################################################################################

###                                           REACTIVE ELEMENTS                                                        ## 

#########################################################################################################################                            

def _parse_hidden_nodes(text):
    """Turn the 'hidden_nodes' text field into a list of layer sizes.

    Entries are comma-separated. Whitespace around entries is ignored, and so
    are blank entries (e.g. a trailing comma while the user is still typing).
    For example, ``'30, 20, 8,'`` yields ``[30, 20, 8]``.

    Args:
        text: raw value of the text input; ``None`` is treated as empty.

    Returns:
        A non-empty list of positive ints.

    Raises:
        ValueError: if an entry is not an integer, any size is below 1, or
            no sizes are given at all.
    """
    raw = '' if text is None else str(text)
    tokens = (tok.strip() for tok in raw.split(','))
    sizes = [int(tok) for tok in tokens if tok]
    if not sizes:
        raise ValueError('at least one hidden layer size is required')
    if min(sizes) < 1:
        raise ValueError('hidden layer sizes must be positive')
    return sizes


@app.callback(
    Output('neural_nets', 'figure'),
    [Input('input_nodes', 'value'),
     Input('hidden_nodes', 'value'),
     Input('output_nodes', 'value')])
def update_graph(input_nodes, hidden_nodes, output_nodes):
    """Redraw the network diagram from the three layer-size inputs.

    Args:
        input_nodes: value of the 'input_nodes' slider.
        hidden_nodes: comma-separated hidden layer sizes from the
            'hidden_nodes' text input, e.g. ``'30, 20, 8'``.
        output_nodes: value of the 'output_nodes' number input (presumably
            None or '' while the field is empty).

    Returns:
        The figure built by ``plot_neural_network`` for the 'neural_nets'
        graph.

    Raises:
        PreventUpdate: if any size is missing, not an integer, or below 1.
            Dash then keeps the previous figure instead of logging a
            traceback for every half-typed value.

    Side effects:
        Stores the parsed sizes in g_input_nodes, g_hidden_nodes (a list) and
        g_output_nodes, which generate_code reads.
    """
    global g_input_nodes, g_hidden_nodes, g_output_nodes

    from dash.exceptions import PreventUpdate

    try:
        input_nodes = int(input_nodes)
        output_nodes = int(output_nodes)
        hidden_nodes = _parse_hidden_nodes(hidden_nodes)
    except (TypeError, ValueError):
        # Empty or non-numeric field: leave figure and globals untouched.
        raise PreventUpdate from None
    if input_nodes < 1 or output_nodes < 1:
        raise PreventUpdate

    node_nbr_dict, node_list = create_node_nbr_dict(input_nodes, hidden_nodes, output_nodes)

    circle_dict_list = create_circles(node_nbr_dict, input_nodes)
    list_dict_list = create_lines(node_nbr_dict, node_list, circle_dict_list)
    final_shape_list = circle_dict_list + list_dict_list

    # Only record the topology once it has been validated and drawn.
    g_input_nodes = input_nodes
    g_hidden_nodes = hidden_nodes
    g_output_nodes = output_nodes

    return plot_neural_network(final_shape_list, input_nodes)

@app.callback(
    Output('the_code', 'children'),
    [Input('code_generator', 'n_clicks'),
     Input('activations_dropdown', 'value'),
     Input('loss_dropdown', 'value'),
     Input('optimizer_dropdown', 'value')])
def generate_code(n_click, activations_dropdown, loss_dropdown, optimizer_dropdown):
    """Build PyTorch source for the current network as a Markdown code block.

    Runs when the "Generate code" button is clicked or any of the three
    dropdowns changes. Layer sizes are read from the module globals
    ``g_input_nodes``, ``g_hidden_nodes`` and ``g_output_nodes``.
    ``g_hidden_nodes`` may be a single int (its initial value) or a list of
    per-layer sizes.

    NOTE(review): the globals are shared by every session served by this
    process. On page load, this callback may also fire before update_graph
    has stored the slider's value. Reading the three layer inputs as
    ``State`` here would remove both problems.

    Args:
        n_click: click count of the button; only used as a trigger.
        activations_dropdown: name from ``activation_functions``, or None if
            the dropdown was cleared.
        loss_dropdown: name from ``loss_functions``, or None.
        optimizer_dropdown: name from ``optimizers``, or None.

    Returns:
        The generated code wrapped in ``` fences. A None or unrecognised
        choice leaves out the matching activation, loss or optimizer line.
    """
    global g_input_nodes, g_hidden_nodes, g_output_nodes

    # Statement emitted after every nn.Linear call in forward().
    # - torch.sigmoid / torch.tanh are used instead of F.sigmoid / F.tanh,
    #   which PyTorch 1.x flags as deprecated.
    # - Softmax gets an explicit dim, because omitting it triggers PyTorch's
    #   implicit-dimension warning. dim=-1 is the axis the implicit choice
    #   picks for 1-D and 2-D input.
    activation_lines = {
        'Sigmoid': '        x = torch.sigmoid(x)\n',
        'Softmax': '        x = F.softmax(x, dim=-1)\n',
        'ReLu': '        x = F.relu(x)\n',
        'Tanh': '        x = torch.tanh(x)\n',
    }
    # Log loss is binary cross-entropy, so it maps to nn.BCELoss
    # (nn.L1Loss would be mean absolute error).
    loss_lines = {
        'Multiclass Cross-Entropy': 'loss = nn.CrossEntropyLoss()\n\n',
        'Binary Cross-Entropy': 'loss = nn.BCELoss()\n\n',
        'Maximum Likelihood': 'loss = nn.NLLLoss()\n\n',
        'Log loss': 'loss = nn.BCELoss()\n\n',
    }
    # '{}' is filled with the network's parameters() call. The SGD-based
    # entries pass an explicit lr because PyTorch 1.x has no default learning
    # rate for SGD and raises when it is omitted.
    optimizer_calls = {
        # NOTE(review): ASGD is *averaged* SGD; confirm this is the intended
        # mapping for plain gradient descent.
        'Gradient Descent': 'torch.optim.ASGD({})',
        'SGD': 'torch.optim.SGD({}, lr=0.01)',
        'RMSProp': 'torch.optim.RMSprop({})',
        'Momentum': 'torch.optim.SGD({}, lr=0.01, momentum=0.9)',
        'NAG': 'torch.optim.SGD({}, lr=0.01, momentum=0.9, nesterov=True)',
        'Adagrad': 'torch.optim.Adagrad({})',
        'Adadelta': 'torch.optim.Adadelta({})',
        'Adam': 'torch.optim.Adam({})',
    }

    activation_str = activation_lines.get(activations_dropdown, '')
    loss_str = loss_lines.get(loss_dropdown, '')
    optimizer_call = optimizer_calls.get(optimizer_dropdown)
    optimizer_str = ''
    if optimizer_call is not None:
        optimizer_str = 'optimizer = {}\n'.format(
            optimizer_call.format('myFirstTorchifiedNetwork.parameters()'))

    if isinstance(g_hidden_nodes, (list, tuple)):
        hidden_sizes = list(g_hidden_nodes)
    else:
        hidden_sizes = [g_hidden_nodes]
    layer_sizes = [g_input_nodes] + hidden_sizes + [g_output_nodes]

    init_code = ('class TorchifiedNetwork(nn.Module):\n\n\n\n    def __init__(self):\n\n'
                 '        super(TorchifiedNetwork, self).__init__()\n\n')
    forward_code = '    def forward(self, x):\n\n'
    # Each consecutive (in, out) size pair becomes one fully-connected layer.
    # With no hidden sizes this still yields a single input->output layer.
    for layer_idx, (n_in, n_out) in enumerate(zip(layer_sizes, layer_sizes[1:]), start=1):
        init_code += '        self.fc{} = nn.Linear({}, {})\n'.format(layer_idx, n_in, n_out)
        forward_code += '        x = self.fc{}(x)\n'.format(layer_idx) + activation_str

    code = ''.join([
        '```\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n\n',
        init_code,
        '\n\n\n' + forward_code + '\n        return x\n\n\n\n',
        'myFirstTorchifiedNetwork = TorchifiedNetwork()\n\n',
        loss_str,
        optimizer_str,
    ])
    return code + '\n\n```'

if __name__ == '__main__':
    # Serve on localhost only, with Dash debug mode switched off.
    app.run_server(host='127.0.0.1', port=8080, debug=False)

#  * Serving Flask app "__main__" (lazy loading)
#  * Environment: production
#    Use a production WSGI server instead.
#  * Debug mode: off


#  * Running on http://127.0.0.1:8080/ (Press CTRL+C to quit)
# 127.0.0.1 - - [08/Jan/2019 05:25:09] "GET / HTTP/1.1" 200 -
# 127.0.0.1 - - [08/Jan/2019 05:25:11] "GET /_dash-layout HTTP/1.1" 200 -
# 127.0.0.1 - - [08/Jan/2019 05:25:11] "GET /_dash-dependencies HTTP/1.1" 200 -
# 127.0.0.1 - - [08/Jan/2019 05:25:11] "POST /_dash-update-component HTTP/1.1" 200 -
# 127.0.0.1 - - [08/Jan/2019 05:25:11] "POST /_dash-update-component HTTP/1.1" 200 -
