In [29]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from sklearn.model_selection import train_test_split

import numpy as np
np.random.seed = 42

import seaborn as sns

import pandas as pd

import matplotlib.pyplot as plt

import random

import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, Conv2D, Flatten, MaxPooling2D, BatchNormalization, Input, Dropout
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
import tensorflow_datasets as tfds

from keras_tuner import HyperParameters, Hyperband

from os.path import isfile

import lzma

import pickle

import jax.numpy as jnp
import jax
from jax import vmap, pmap, lax

from flax import linen as nn
from flax.training import train_state
from flax import jax_utils
import optax

import time

import functools
from functools import partial

**Импорт данных**

In [2]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [3]:
width, height, depth = X_train.shape[1], X_train.shape[2], 1

X_train = X_train.reshape(X_train.shape[0], width, height, depth)
X_train = X_train.astype(np.float32)

X_test = X_test.reshape(X_test.shape[0], width, height, depth)
X_test = X_test.astype(np.float32)

y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

In [4]:
W_SHIFT = [-1, 1]
H_SHIFT = [-1, 1]
ROT_ANGLE = 20
ZOOM_RANGE = [0.7, 1.2]

In [5]:
BATCH_SIZE = 32
P_WIDTH = 28
P_HEIGHT = 28
P_DEPTH = 1
N_CLASSES = 10
NUM_EPOCHS = 3

In [6]:
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, shuffle=True, stratify=y_train)

In [7]:
train_generator = ImageDataGenerator(
    width_shift_range=W_SHIFT, 
    height_shift_range=H_SHIFT,
    rotation_range=ROT_ANGLE,
    zoom_range=ZOOM_RANGE,
    samplewise_center=True,
    samplewise_std_normalization=True,
    rescale= 1.0/255.0
)

train_generator.fit(X_train)
train_iterator = train_generator.flow(X_train, y_train, batch_size=BATCH_SIZE)

In [8]:
val_generator = ImageDataGenerator(
    width_shift_range=W_SHIFT, 
    height_shift_range=H_SHIFT,
    rotation_range=ROT_ANGLE,
    zoom_range=ZOOM_RANGE,
    samplewise_center=True,
    samplewise_std_normalization=True,
    rescale= 1.0/255.0
)

val_generator.fit(X_val)
val_iterator = val_generator.flow(X_val, y_val, batch_size=BATCH_SIZE)

In [9]:
test_generator = ImageDataGenerator(
    samplewise_center=True,
    samplewise_std_normalization=True,
    rescale=1.0/255.0
)

test_generator.fit(X_test)
test_iterator = test_generator.flow(X_test, y_test, batch_size=BATCH_SIZE)

In [10]:
def make_dataset_from_iterator(steps_count = 0, jax = True, iterator = None):
    """Generates dataset from ImageDataGenerator"""

    iterator.reset()

    result = [[], []]

    if jax:
        result = {
            'image': [], 
            'label': []
        }

    for _ in range(steps_count):
        batch_ds = tf.data.Dataset.from_generator(
            lambda: iterator,
            output_types=(tf.float32, tf.float32),
            output_shapes=(
                [BATCH_SIZE, P_WIDTH, P_HEIGHT, P_DEPTH],
                [BATCH_SIZE, N_CLASSES]
            )
        )

        temp_ds_iterator = iter(batch_ds)
        batch = next(temp_ds_iterator)

        if jax:
            result['image'].append(tfds.as_numpy(batch[0]))
            result['label'].append(np.argmax(tfds.as_numpy(batch[1]), axis=1))
        else:
            result[0].append(np.float32(tfds.as_numpy(batch[0])))
            result[1].append(np.float32(tfds.as_numpy(batch[1])))
    
    if jax:
        result['image'] = jnp.float32(np.array(result['image']).reshape(steps_count * BATCH_SIZE, P_WIDTH, P_HEIGHT, P_DEPTH))
        result['label'] = jnp.float32(np.array(result['label']).flatten())
        return result
    else:
        return np.array(result[0]).reshape(steps_count * BATCH_SIZE, P_WIDTH, P_HEIGHT, P_DEPTH), np.array(result[1]).reshape(steps_count * BATCH_SIZE, N_CLASSES)


In [11]:
STEP_SIZE_TRAIN = train_iterator.n // train_iterator.batch_size
STEP_SIZE_VAL = val_iterator.n // val_iterator.batch_size
STEP_SIZE_TEST = test_iterator.n // test_iterator.batch_size

In [12]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices

In [18]:
train_ds = None
val_ds = None
test_ds = None
datasets_list = []

if isfile("./data/datasets.xz"):
    with lzma.open("./data/datasets.xz", "rb") as m_file:
        datasets_list = pickle.load(m_file)
        train_ds = datasets_list[0]
        val_ds = datasets_list[1]
        test_ds = datasets_list[2]
else:
    train_ds = make_dataset_from_iterator(STEP_SIZE_TRAIN, True, train_iterator)
    datasets_list.append(train_ds)

    val_ds = make_dataset_from_iterator(STEP_SIZE_VAL, True, val_iterator)
    datasets_list.append(val_ds)

    test_ds = make_dataset_from_iterator(STEP_SIZE_TEST, True, test_iterator)
    datasets_list.append(test_ds)

    with lzma.open("./data/datasets.xz", "wb") as m_file:
        pickle.dump(datasets_list, m_file)


2021-12-05 10:34:04.809469: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2021-12-05 10:34:04.809553: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2021-12-05 10:34:04.809591: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (DESKTOP-LDSDKNA): /proc/driver/nvidia/version does not exist


In [19]:
type(train_ds)

dict

In [20]:
TRAIN_DS_SIZE = len(train_ds['image'])
LEARNING_RATE = 0.01
MOMENTUM = 0.9

In [32]:
@nn.compact
def __call__(self, x_in):
    x = nn.Conv(features=32, kernel_size=(3, 3), padding='SAME')(x_in)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=128)(x)
    x = nn.relu(x)
    
    x = nn.Dense(features=N_CLASSES)(x)
    x = nn.log_softmax(x)
    return x


In [47]:
print("Jax devices count: {}\nJax local devices count: {}".format(jax.device_count(), jax.local_device_count()))
mesh_devices = np.array(jax.local_devices())

Jax devices count: 8
Jax local devices count: 8


In [31]:
def cross_entropy_loss(logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=N_CLASSES)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))


In [55]:
from jax.experimental.maps import mesh, xmap

In [56]:
def test(x):
    return jnp.sum(x)

In [60]:
in_ax = [['a', ...]]

In [63]:
custom_summ = xmap(test, in_axes=in_ax, out_axes=in_ax)

  warn("xmap is an experimental feature and probably has bugs!")


In [64]:
a = np.arange(1000)

b = custom_summ(a)


ValueError: xmap out_axes specification must be a tree prefix of the corresponding value, got specification [['a', ...]] for value tree PyTreeDef(*).