# ０　元ネタ：ここを参考に実装しました
 - CNNの実装：Pytorch 学習済みモデルで識別[MNIST]
  - https://qiita.com/TKC-tkc/items/3b41620ecb9b22901413
 - IICの実装：相互情報量の最大化による教師なし学習手法IICの登場！
  - https://ai-scholar.tech/articles/treatise/iic-ai-367
 - アプリ開発：Jupyter上でDashを使えるjupyter_dash
  - https://qiita.com/OgawaHideyuki/items/725f4ffd93ffb0d30b6c
 - dash-canvas-ocr：手書き入力部分は、こちらを参考にしました
  - https://github.com/plotly/dash-sample-apps/tree/master/apps/dash-canvas-ocr
 - トンネリング：【Argo Tunnel】StreamlitアプリをGoogleColabから秒で外部公開する
  - https://www.space-i.com/post-blog/googlecola%E4%B8%8A%E3%81%8B%E3%82%89streamlit%E3%82%A2%E3%83%97%E3%83%AA%E3%82%92-cloudflare%E7%B5%8C%E7%94%B1%E3%81%A7%E7%A7%92%E3%81%A7%E5%A4%96%E9%83%A8%E5%85%AC%E9%96%8B%E3%81%99%E3%82%8B/

# １　必要なライブラリのインストール

In [None]:
# Install the Dash packages used by the web app (IPython "!" shell escapes;
# these lines only run inside Jupyter/Colab).
# NOTE(review): torchsummary is imported in the next section but not
# installed here — presumably already available in the runtime; confirm.
!pip install jupyter_dash
!pip install dash_canvas dash-bootstrap-components

# ２　ライブラリのインポート

## ディープラーニング関係

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import torchvision
from torchvision import datasets, transforms
from PIL import Image, ImageOps
from torchsummary import summary

## Webアプリ関係

In [None]:
import base64
from io import BytesIO
import pickle
from jupyter_dash import JupyterDash 
import numpy as np
import dash_html_components as html
import dash_core_components as dcc
import dash_table
from dash_canvas import DashCanvas
from dash_canvas.utils import array_to_data_url, parse_jsonstring
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from dash_table.Format import Format, Scheme
import dash_bootstrap_components as dbc

# ３　学習済みモデルのロード（読み込み）

## ＣＮＮ学習済みモデルのロード

In [None]:
# Model definition: small CNN digit classifier.
class MyNet(nn.Module):
    """Two conv layers, max-pool and two fully connected layers.

    The flatten size ``12*12*64`` matches a ``(batch, 1, 28, 28)`` input
    (28 -> 26 -> 24 after the 3x3 convs, -> 12 after the 2x2 pool).
    ``forward`` returns log-probabilities over 10 classes.

    Attribute names are the state_dict keys of the saved checkpoint, so they
    must not change.

    ``dropout2`` acts on the 2-D ``(batch, 128)`` fc1 activations, so it is a
    plain ``nn.Dropout``. Recent PyTorch deprecates 2-D input to ``Dropout2d``
    and interprets it as an unbatched ``(C, L)`` tensor. That drops whole
    samples instead of individual units. Dropout has no parameters, so
    checkpoints saved with the old definition load unchanged, and in eval
    mode both are the identity.
    """

    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.pool = nn.MaxPool2d(2, 2)
        # Channel-wise dropout on the 4-D conv feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the 2-D fully connected activations.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(12 * 12 * 64, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = f.relu(self.conv1(x))
        x = f.relu(self.conv2(x))
        x = self.dropout1(self.pool(x))
        x = x.view(-1, 12 * 12 * 64)
        x = self.dropout2(f.relu(self.fc1(x)))
        x = self.fc2(x)
        return f.log_softmax(x, dim=1)

# Use the GPU when one is available and build the CNN on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_cnn = MyNet().to(device)
print(device)
print(model_cnn)
# torchsummary's summary() prints the table itself (the PyPI release returns
# None), so it is called directly; wrapping it in print() only added a stray
# "None" line. Passing the device keeps the dummy input on the model's device.
summary(model_cnn, (1, 28, 28), device=device.type)

# Load the trained weights.
PATH = "/content/drive/MyDrive/0_HanPy41/0_Welcomeデモ/model_CNN.pt"
# map_location=device deserialises the tensors straight onto the selected
# device, so a checkpoint saved on GPU also loads on a CPU-only runtime.
model_cnn.load_state_dict(torch.load(PATH, map_location=device))
# Inference mode: disables dropout.
model_cnn = model_cnn.eval()

## ＩＩＣ学習済みモデルのロード（読み込み）

In [None]:
# ディープラーニングモデル
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def weight_init(m):
    """Initialise the parameters of one module in place.

    Intended for ``nn.Module.apply``. Only Conv2d, BatchNorm2d and Linear
    modules are touched; any other module is returned from unchanged.

    - Conv2d: Xavier-normal weights; bias (if present) drawn from N(0, 1).
    - BatchNorm2d: weights drawn from N(1, 0.02); bias set to 0.
    - Linear: Kaiming-normal weights; bias (if present) drawn from N(0, 1).
    """
    if isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
        return

    # Conv2d and Linear share the same bias treatment; only the weight
    # initialiser differs.
    if isinstance(m, nn.Conv2d):
        init_weight = init.xavier_normal_
    elif isinstance(m, nn.Linear):
        init_weight = init.kaiming_normal_
    else:
        return

    init_weight(m.weight.data)
    if m.bias is not None:
        init.normal_(m.bias.data)

# Multiplier for the number of clusters of the over-clustering head.
OVER_CLUSTRING_Rate = 10


class NetIIC(nn.Module):
    """Convolutional encoder with two softmax heads, used for IIC clustering.

    ``forward`` returns ``(y, y_overclustering)``:

    - ``y``: probabilities over 10 clusters, hoped to correspond to 0-9.
    - ``y_overclustering``: probabilities over ``10 * OVER_CLUSTRING_Rate``
      clusters. Clustering into more groups than expected lets the network
      pick up finer variations.

    The flatten to 256 features requires the four convolutions to shrink the
    input to 1x1. This holds for 28x28 input (28 -> 12 -> 8 -> 4 -> 1).
    Attribute names are the checkpoint's state_dict keys and must not change.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 128, 5, 2, bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 128, 5, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 128, 5, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 4, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(256)
        # Main head: 10 clusters.
        self.fc = nn.Linear(256, 10)
        # Over-clustering head.
        self.fc_overclustering = nn.Linear(256, 10 * OVER_CLUSTRING_Rate)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in stages:
            x = F.relu(bn(conv(x)))
        features = x.view(x.size(0), -1)
        main_probs = F.softmax(self.fc(features), dim=1)
        over_probs = F.softmax(self.fc_overclustering(features), dim=1)
        return main_probs, over_probs

# GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the IIC network, initialise it and move it to the device. The random
# initialisation is replaced right away by the checkpoint below:
# load_state_dict is strict by default, so every parameter and buffer is
# overwritten.
model_IIC = NetIIC().apply(weight_init).to(device)

# Load the trained weights. The checkpoint is deserialised onto the CPU, and
# load_state_dict then copies it into the model's parameters.
PATH = "/content/drive/MyDrive/0_HanPy41/0_Welcomeデモ/model_IIC.pt"
model_IIC.load_state_dict(torch.load(PATH, map_location="cpu"))
# Inference mode: BatchNorm uses its running statistics.
model_IIC = model_IIC.eval()

# ４　Ｗｅｂアプリの実装

## 入力イメージの補正関数

In [None]:
def get_box(image, threshold=150):
    """Return the bounding box of the ink (bright) pixels of a grayscale image.

    Args:
        image: single-channel image (PIL mode ``L`` image or 2-D array).
            Pixels strictly greater than ``threshold`` count as ink.
        threshold: brightness cut-off for ink; defaults to 150.

    Returns:
        ``(box, height, width)``.

        - ``box`` is ``(left, upper, right + 1, bottom + 1)``, the
          exclusive-end form accepted by ``PIL.Image.crop``.
        - ``height`` is ``bottom - upper`` and ``width`` is ``right - left``.
          Both are one less than the box's pixel extent.

    Raises:
        ValueError: if the image is not 2-D, or if no pixel exceeds
            ``threshold`` (e.g. an empty canvas).
    """
    ink = np.asarray(image) > threshold
    if ink.ndim != 2:
        raise ValueError(f"get_box expects a 2-D grayscale image, got shape {ink.shape}")

    # Vectorised replacement for a per-pixel Python loop.
    rows, cols = np.nonzero(ink)
    if rows.size == 0:
        raise ValueError(f"get_box: no pixel brighter than {threshold}; nothing was drawn")

    upper, bottom = rows.min(), rows.max()
    left, right = cols.min(), cols.max()

    return (left, upper, right + 1, bottom + 1), bottom - upper, right - left

def treat_image(image):
    """Crop the drawn digit and re-centre it on a 200x200 black canvas.

    The ink bounding box (from ``get_box``) is scaled, keeping its aspect
    ratio, so that its height becomes 160 px. It is then pasted horizontally
    centred with a 20 px top margin.

    Edge cases:

    - If the scaled width would exceed the 200 px canvas, the digit is
      scaled to the canvas width and centred vertically. Pasting it as-is
      would clip both sides.
    - A zero height is clamped to 1 to avoid division by zero.
    - Sizes are kept at least 1 px, because PIL cannot resize to 0.

    Args:
        image: grayscale PIL image with bright ink on a dark background.

    Returns:
        A new 200x200 mode ``L`` PIL image.

    Raises:
        ValueError: propagated from ``get_box`` when no ink is found.
    """
    canvas_size, target_h = 200, 160

    box, h, w = get_box(image)
    h = max(h, 1)
    im_crop = image.crop(box)

    new_w, new_h = int(w / h * target_h), target_h
    if new_w > canvas_size:
        # Fit the width instead of letting the paste clip the digit.
        new_w, new_h = canvas_size, int(h / w * canvas_size)
    new_w, new_h = max(new_w, 1), max(new_h, 1)

    im_base = im_crop.resize((new_w, new_h))
    im_mask = Image.new("L", (canvas_size, canvas_size), 0)
    # In the normal case these offsets equal the historical
    # (100 - int(w/h*80), 20).
    offset = (canvas_size // 2 - new_w // 2, (canvas_size - new_h) // 2)
    im_mask.paste(im_base, offset)

    return im_mask

## Ｗｅｂアプリの本体

In [None]:
# Dash app served from the notebook (JupyterDash), styled with the Bootstrap
# JOURNAL theme. The viewport meta tag targets mobile browsers.
app = JupyterDash(
    __name__,
    meta_tags=[
        {"name": "viewport", "content": "width=device-width, initial-scale=1.0"},
    ],
    external_stylesheets=[dbc.themes.JOURNAL],
)

# Size of the (square) drawing canvas, in pixels.
canvas_width = canvas_height = 224

# IIC clusters carry no labels, so each cluster index is mapped to a digit by
# hand: index i maps to the i-th character of the string below.
# NOTE(review): presumably obtained by inspecting the trained model's
# clusters; must be redone if the IIC checkpoint changes — confirm.
dic_IIC = dict(enumerate("6271453908"))

# Page layout: three horizontally centred rows.
# 1) title, 2) one-row result table, 3) drawing canvas.
app.layout = dbc.Container(
    [
        # Row 1: title.
        dbc.Row(
            [dbc.Col([html.H4("数字識別ＡＩ対決")], width="auto")],
            justify="center",
            style={"margin": "10px 0px 0px 0px"},
        ),
        # Row 2: table with one column per model; update_data fills the row.
        dbc.Row(
            [
                dbc.Col(
                    [
                        dash_table.DataTable(
                            id="id-table",
                            style_cell={
                                "textAlign": "center",
                                "fontSize": 18,
                                "width": 70,
                                "fontWeight": "bold",
                            },
                            columns=[
                                {"name": "学習モデル", "id": "モデル"},
                                {"name": " CNN ", "id": "ＣＮＮ"},
                                {"name": " IIC ", "id": "ＩＩＣ"},
                            ],
                            data=[{"モデル": "識別結果", "ＣＮＮ": "", "ＩＩＣ": ""}],
                        )
                    ],
                    width="auto",
                )
            ],
            justify="center",
            style={"margin": "20px 0px 0px 0px"},
        ),
        # Row 3: drawing canvas with black strokes; the listed toolbar
        # buttons are hidden.
        # NOTE(review): the go button (labelled 識別) presumably submits the
        # drawing, which updates json_data and triggers update_data — confirm.
        dbc.Row(
            [
                dbc.Col(
                    [
                        DashCanvas(
                            id="id-canvas",
                            lineWidth=16,
                            width=canvas_width,
                            height=canvas_height,
                            hide_buttons=["zoom", "pan", "line", "pencil", "rectangle", "select"],
                            lineColor="black",
                            goButtonTitle="　識　　別　",
                        ),
                    ],
                    width="auto",
                )
            ],
            justify="center",
            style={"margin": "20px 0px 0px 0px"},
        ),
    ]
)


@app.callback(
    Output("id-table", "data"),
    [Input("id-canvas", "json_data")],
)
def update_data(string):
    """Classify the digit drawn on the canvas with both the CNN and IIC models.

    Args:
        string: the canvas's ``json_data``, a JSON description of the strokes.
            It is falsy until something is submitted.

    Returns:
        Records for the ``id-table`` DataTable. Normally this is one row with
        the CNN and IIC predictions. If the drawing cannot be parsed, the row
        carries an error message instead.

    Raises:
        PreventUpdate: when there is nothing to classify; the table is left
            unchanged.
    """
    if not string:
        raise PreventUpdate

    try:
        mask = parse_jsonstring(string, shape=(canvas_height, canvas_width))
    except Exception:
        # DataTable's ``data`` must be a list of row dicts, so the message is
        # shown as a table row rather than returned as a bare string.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return [{
            "モデル": "Out of Bounding Box, click clear button and try again",
            "ＣＮＮ": "",
            "ＩＩＣ": "",
        }]

    # Render the inverted stroke mask as an 8-bit image: 255 where the mask
    # was 0, 0 elsewhere.
    mask = (~mask.astype(bool)).astype(int)
    image_url = array_to_data_url((255 * mask).astype(np.uint8))
    # Drop the "data:<mime>;base64," header without assuming its length.
    _, _, payload = image_url.partition(",")
    image = Image.open(BytesIO(base64.b64decode(payload)))

    # Invert, then crop and centre the digit.
    image = ImageOps.invert(image)
    try:
        image = treat_image(image)
    except (ValueError, ZeroDivisionError):
        # No ink found (e.g. an empty canvas was submitted) or a degenerate
        # stroke: nothing meaningful to classify.
        raise PreventUpdate from None

    # Same normalisation (mean/std) as used when the models were trained.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Add the batch dimension the models expect.
    batch = transform(image.resize((28, 28))).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # IIC: the most probable cluster, translated to a digit via dic_IIC.
        iic_probs, _ = model_IIC(batch)
        v_IIC = dic_IIC[int(iic_probs.argmax(dim=1).cpu())]

        # CNN: the most probable class.
        cnn_out = model_cnn(batch)
        _, prediction = torch.max(cnn_out, 1)
        v_CNN = str(prediction[0].item())

    return [{"モデル": "識別結果", "ＣＮＮ": v_CNN, "ＩＩＣ": v_IIC}]

# ５　Ｗｅｂアプリをローカル環境で実行

In [None]:
# app.run_server(mode="inline")

# ６　Ｗｅｂアプリのデプロイ

## externalモードでWebアプリを起動

In [None]:
# Serve the app on localhost:8030 in JupyterDash's "external" mode; this is
# the port the cloudflared tunnel below exposes.
app.run_server(
    host="localhost",
    port=8030,
    mode="external",
    debug=False,
)

## トンネリングで公開ＵＲＬを発行

In [None]:
# Install cloudflared and issue a public URL that tunnels to localhost:8030,
# the port the Dash app is served on (IPython "!" shell escapes).
# NOTE(review): the bin.equinox.io download location may be outdated;
# cloudflared packages are published at
# github.com/cloudflare/cloudflared/releases. Confirm this URL still resolves.
!wget https://bin.equinox.io/c/VdrWdbjqyF/cloudflared-stable-linux-amd64.deb
!dpkg -i cloudflared-stable-linux-amd64.deb
!cloudflared tunnel --url localhost:8030