In [None]:
#------------------------
# Google Colab上でのみ実行
#------------------------
import time
%env TOKEN=*************************************
! git clone https://$$TOKEN@github.com/konnitiha3/MOD2NN.git

import sys
sys.path.append('/content/MOD2NN')

from google.colab import drive
drive.mount('/content/drive')

In [3]:
import time
import os
import json
import sys
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from Faraday.two_dim.module.lib.layers import *
from Faraday.two_dim.module.lib import regularizer

# Fix TensorFlow's global random seed so runs are repeatable.
tf.random.set_seed(1)

# Record the runtime versions next to the results.
for label, version in (("TensorFlow:", tf.__version__), ("Python:", sys.version)):
    print(label, version)

# Larger default font size for every figure drawn in this notebook.
plt.rcParams.update({'font.size': 18})

TensorFlow: 2.8.0
Python: 3.8.11 (default, Aug 16 2021, 12:04:33) 
[Clang 12.0.0 ]


In [4]:
# Load MNIST and scale the uint8 pixel values into [0.0, 1.0].
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Class labels are kept as single indices (cast to float32); they are used
# with sparse_categorical_crossentropy in the training cell.
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

# Reserve the last N_VAL training samples for validation.
N_VAL = 5000
assert 0 <= N_VAL < len(x_train), N_VAL
# Use an explicit split index: `x[:-N_VAL]` would give an EMPTY array when
# N_VAL == 0, whereas this keeps all samples for training in that case.
split = len(x_train) - N_VAL
x_val = x_train[split:]
y_val = y_train[split:]
x_train = x_train[:split]
y_train = y_train[:split]

In [5]:
#@title デフォルトのタイトル テキスト
# (The title above is Colab's auto-generated placeholder, "Default title text".)
# Notebook-wide optical parameters exposed as Colab form fields. Colab parses
# the "#@title"/"#@param" annotations to build the form, so keep each
# "#@param" on its assignment line.
# wavelength: presumably in metres (532e-9 -> 532 nm) — TODO confirm units.
# d: passed as `d=` to every AngularSpectrum layer in create_model; presumably
#    the sampling pitch in metres — TODO confirm.
# n: NOTE(review): create_model does not read this; it passes n=1.51 and
#    n=1.00 to AngularSpectrum directly. Confirm which value is intended.
wavelength = 532.0e-9 #@param {type:"number"}
d = 1.0e-6 #@param {type:"number"}
n = 1.5 #@param {type:"number"}

In [6]:
def create_model(**kwargs):
    """Build the two-layer magneto-optical (MO) diffractive network model.

    The input image is resized to ``shape`` and passed through
    ImageBinarization(0.5, 0.0, 1.0), then converted to an electric field.
    The field goes through MO -> AngularSpectrum -> MO -> AngularSpectrum
    (the last propagation uses ``z``). After that come
    FaradayRotationByStokes, CircleOnCircumferenceDetector(10, 30, 5) and a
    softmax.

    Keyword Args:
        z (float): Required. ``z`` of the final AngularSpectrum layer, the
            propagation after the second MO layer (presumably metres — TODO
            confirm).
        shape (tuple[int, int]): Simulation grid size. Default ``(100, 100)``.
        input_shape (tuple[int, int]): Keras input shape. Default ``(28, 28)``.
        theta (float): ``theta`` of both MO layers. Default: -2.79 degrees
            converted to radians.
        eta (float): ``eta`` of both MO layers. Default:
            ``arctan(1.24 degrees in radians) / 2``.
        l1 (float): Strength of the ShiftL1Regularizer (shift pi/2) on each
            MO kernel. Default ``1.0e-5``.
        inter_layer_z (float): ``z`` of the AngularSpectrum between the two
            MO layers. Default ``0.7e-3``.
        inter_layer_n (float): ``n`` of that AngularSpectrum. Default ``1.51``.
        wavelength (float): Overrides the notebook-level ``wavelength``.
        d (float): Overrides the notebook-level ``d``.

    Returns:
        tf.keras.Model: The uncompiled model (softmax output).

    Raises:
        KeyError: If ``z`` is not supplied.
    """
    # Re-seed so every configuration in a sweep starts from the same weights.
    tf.random.set_seed(1)
    shape = kwargs.get("shape", (100, 100))
    inputs = tf.keras.Input(kwargs.get("input_shape", (28, 28)))
    theta = kwargs.get("theta", -2.79 * np.pi / 180)
    eta = kwargs.get("eta", np.arctan(1.24 * np.pi / 180) / 2)
    l1 = kwargs.get("l1", 1.0e-5)
    inter_layer_z = kwargs.get("inter_layer_z", 0.7e-3)
    inter_layer_n = kwargs.get("inter_layer_n", 1.51)
    # Distinct local names: assigning to `wavelength`/`d` inside this function
    # would make them local and raise UnboundLocalError on the lookup.
    wl = kwargs.get("wavelength", wavelength)
    pitch = kwargs.get("d", d)
    z = kwargs["z"]  # required
    # Echo the configuration so each run of a sweep is identifiable in the log.
    print(kwargs)

    def mo_layer():
        # Both MO layers share one configuration; each call builds a new layer
        # with its own weights and its own regularizer instance.
        return MO(shape, limitation='sin', theta=theta, eta=eta,
                  kernel_regularizer=regularizer.ShiftL1Regularizer(l1, np.pi / 2))

    x = ImageResizing(shape)(inputs)
    x = ImageBinarization(0.5, 0.0, 1.0)(x)
    x = IntensityToElectricField(shape)(x)
    x = mo_layer()(x)
    x = AngularSpectrum(shape, wavelength=wl, z=inter_layer_z, d=pitch,
                        n=inter_layer_n, method='expand')(x)
    x = mo_layer()(x)
    x = AngularSpectrum(shape, wavelength=wl, z=z, d=pitch, n=1.00, method='expand')(x)
    x = FaradayRotationByStokes(shape)(x)
    x = CircleOnCircumferenceDetector(10, 30, 5)(x)
    x = tf.keras.layers.Softmax()(x)
    return tf.keras.Model(inputs, x)

In [None]:
# Sweep the final propagation distance z: build, train and save one model per
# value.
z_list = np.array([0.3e-3, 0.4e-3, 0.5e-3, 0.6e-3, 0.7e-3, 0.8e-3, 0.9e-3, 1.0e-3])
# A typo such as `0.5-3` evaluates to -2.5 and would silently train with a
# negative distance; fail fast instead.
assert np.all(z_list > 0), z_list

epochs = 50
cholab_path = "/content/drive/MyDrive/D2NN/"

# Optional callbacks. Both were configured but not passed to fit(), so they
# stay disabled by default to keep training unchanged.
# NOTE(review): min_delta=0.05 on training accuracy is strict; enabling early
# stopping will likely end training after only a few epochs.
USE_EARLY_STOPPING = False
USE_TENSORBOARD = False


def _json_default(obj):
    """json.dump fallback: convert NumPy scalars/arrays to plain Python types."""
    # Custom layer configs may contain NumPy values, which json cannot
    # serialise natively (np.float64 is a float subclass, np.float32 is not).
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


for i, z in enumerate(z_list):
    model = create_model(z=z)
    model.summary()
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,  # labels are class indices
                  metrics=['accuracy'])

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='accuracy',
        min_delta=0.05,
        patience=2,
    )

    model_name = "20220802_" + str(i + 1)
    checkpoint_path = cholab_path + "checkpoint/" + model_name + "/cp-{epoch:04d}.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)

    # Save the weights at the end of every epoch.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                     save_weights_only=True,
                                                     verbose=1)

    logdir = os.path.join(cholab_path + "logs", model_name)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

    callbacks = [cp_callback]
    if USE_EARLY_STOPPING:
        callbacks.append(early_stopping)
    if USE_TENSORBOARD:
        callbacks.append(tensorboard_callback)

    result = model.fit(x_train,
                       y_train,
                       batch_size=64,
                       epochs=epochs,
                       validation_data=(x_val, y_val),
                       callbacks=callbacks
                       )

    path = cholab_path + "trained_model/" + model_name
    model.save(path)

    df = pd.DataFrame(result.history)
    df.to_csv(os.path.join(path, "history.csv"))

    with open(os.path.join(path, "config.json"), 'w', encoding="utf-8") as f:
        json.dump(model.get_config(), f, indent=4, default=_json_default)