In [1]:
# import tensorflow as tf
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

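Windows workaround: in `.\Lib\site-packages\tensorflow_datasets\core\shuffle.py`, comment out the line `import resource` (change it to `# import resource`); Python's `resource` module is only available on Unix.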

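Pinned environment used for this notebook (`tensorflow-addons`, imported further below, is also required but is not pinned here):

`pip install protobuf==3.19.6 tensorflow-metadata==1.10.0 tensorflow_datasets==4.6.0 tensorflow-gpu==2.10.0`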

In [2]:
# Record the TensorFlow and TFDS versions (one per line) so results can be
# matched against the pinned environment.
print(tf.__version__, tfds.__version__, sep="\n")

2.10.0
4.6.0


In [3]:
# `is_built_with_cuda()` only reports how the TensorFlow binary was compiled;
# it does not tell whether a GPU is actually visible at runtime, so also list
# the physical GPU devices TensorFlow can see (an empty list means CPU only).
print(tf.test.is_built_with_cuda())
print(tf.config.list_physical_devices('GPU'))

True


In [4]:
def prepare_tf_data(Train_size=32, Test_size=16, img_size=224, random_ratio=0.1,
                    shuffle_buffer=10000):
    """Build batched CIFAR-10 train/test ``tf.data`` pipelines from TFDS.

    Both splits are cast to float32, scaled to [0, 1] and resized to
    ``(img_size, img_size)``. The training split is additionally shuffled and
    augmented with a random horizontal flip and random brightness/contrast.

    Args:
        Train_size: batch size of the training dataset.
        Test_size: batch size of the test dataset.
        img_size: side length the images are resized to.
        random_ratio: strength of the colour jitter. Brightness is shifted by up
            to ``random_ratio`` and contrast is scaled by a factor drawn from
            ``[1 - random_ratio, 1 + random_ratio]``. Must lie in ``[0, 1]``;
            ``0`` disables the colour jitter (only the flip remains).
        shuffle_buffer: number of training examples held in the shuffle buffer.

    Returns:
        ``(train_ds, test_ds)``: prefetched datasets of ``(image, label)`` batches.

    Raises:
        ValueError: if ``random_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= random_ratio <= 1:
        raise ValueError(f"random_ratio must be in [0, 1], got {random_ratio!r}")

    train_ds, test_ds = tfds.load(
        'cifar10',
        split=['train', 'test'],
        as_supervised=True,
    )

    def preprocess(image, label):
        # Integer pixels -> float32 in [0, 1], then resize to the model's input size.
        image = tf.cast(image, tf.float32) / 255.0
        image = tf.image.resize(image, (img_size, img_size))
        return image, label

    def augment(image, label):
        image = tf.image.random_flip_left_right(image)
        # random_contrast rejects an empty range (lower == upper), so the
        # colour jitter is only applied for a strictly positive ratio.
        if random_ratio > 0:
            image = tf.image.random_brightness(image, random_ratio)
            image = tf.image.random_contrast(image, 1 - random_ratio, 1 + random_ratio)
            # Brightness/contrast jitter can push pixels outside [0, 1]; clip so
            # augmented training images stay in the same range as test images.
            image = tf.clip_by_value(image, 0.0, 1.0)
        return image, label

    AUTOTUNE = tf.data.AUTOTUNE

    train_ds = (train_ds
                # Shuffle the raw integer images *before* the float cast and
                # resize: otherwise the buffer holds `shuffle_buffer` float32
                # images at img_size x img_size (about 6 GB for 10000 RGB
                # images at 224x224).
                .shuffle(shuffle_buffer)
                .map(preprocess, num_parallel_calls=AUTOTUNE)
                .map(augment, num_parallel_calls=AUTOTUNE)
                .batch(Train_size)
                .prefetch(AUTOTUNE))

    # The test split is never shuffled or augmented.
    test_ds = (test_ds
               .map(preprocess, num_parallel_calls=AUTOTUNE)
               .batch(Test_size)
               .prefetch(AUTOTUNE))

    return train_ds, test_ds

In [51]:
# Build the CIFAR-10 pipelines at 32x32, matching the ViT input_shape=(32, 32, 3)
# used below; the batch sizes and augmentation strength are spelled out explicitly.
train_ds, test_ds = prepare_tf_data(
    Train_size=32,
    Test_size=16,
    img_size=32,
    random_ratio=0.1,
)

In [52]:
from models.vit import VisionTransformer

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class PatchEmbed(layers.Layer):
    """Cut images into non-overlapping square patches and embed each one linearly.

    A Conv2D whose kernel size equals its stride projects every
    ``patch_size x patch_size`` patch to ``embed_dim`` channels; the resulting
    spatial grid is then flattened into a single token axis.
    """

    def __init__(self, patch_size, embed_dim):
        super().__init__()
        self.proj = layers.Conv2D(embed_dim, patch_size, strides=patch_size)

    def call(self, x):
        # assumes channels-last input (the Conv2D default), so the projection is
        # (batch, rows, cols, embed_dim) and becomes (batch, rows * cols, embed_dim)
        patches = self.proj(x)
        dims = tf.shape(patches)
        return tf.reshape(patches, [dims[0], -1, dims[-1]])

class VisionTransformer(keras.Model):
    """Vision Transformer (ViT) image classifier that returns logits.

    Images are split into patches by ``PatchEmbed``. A learnable class token is
    prepended and learnable position embeddings are added. The token sequence
    passes through ``depth`` ``TransformerBlock``s, and the LayerNorm'd class
    token is projected to ``num_classes`` logits (no softmax).

    Args:
        input_shape: image shape ``(height, width, ...)``; only height and width
            are read, to size the position embedding.
        patch_size: side length of each square patch.
        num_classes: number of output logits.
        embed_dim: token embedding width.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: dropout rate forwarded to every ``TransformerBlock``.
    """

    def __init__(
        self,
        input_shape,
        patch_size,
        num_classes,
        embed_dim,
        depth,
        num_heads,
        mlp_ratio=4,
        dropout=0.1
    ):
        super().__init__()

        # A kernel == stride Conv2D with 'valid' padding yields
        # floor(H / p) x floor(W / p) patches.
        num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
        self.patch_embed = PatchEmbed(patch_size, embed_dim)

        self.cls_token = self.add_weight(
            "cls_token", shape=[1, 1, embed_dim],
            initializer="zeros", trainable=True
        )
        # One position embedding per patch plus one for the class token.
        self.pos_embed = self.add_weight(
            "pos_embed", shape=[1, num_patches + 1, embed_dim],
            initializer="zeros", trainable=True
        )

        # The blocks are applied by *calling* this Sequential in `call`.
        self.blocks = tf.keras.Sequential([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = layers.LayerNormalization()
        self.head = layers.Dense(num_classes)

    def call(self, x):
        batch_size = tf.shape(x)[0]
        x = self.patch_embed(x)

        # Prepend one copy of the class token per example, then add positions.
        cls_tokens = tf.repeat(self.cls_token, batch_size, axis=0)
        x = tf.concat([cls_tokens, x], axis=1)
        x = x + self.pos_embed

        # Bug fix: the previous `for block in self.blocks:` fails because a
        # keras.Sequential object is not iterable. Calling the Sequential runs
        # every TransformerBlock in order.
        x = self.blocks(x)

        x = self.norm(x)
        # Classify from the class token at sequence position 0.
        return self.head(x[:, 0])

class TransformerBlock(layers.Layer):
    """Pre-norm transformer encoder block.

    ``x + Dropout(MHA(LN(x)))`` followed by ``x + Dropout(MLP(LN(x)))``.

    Args:
        dim: token width; each attention head gets ``dim // num_heads`` key width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout: dropout rate applied to each sub-layer's output before it is
            added back to the residual stream.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4, dropout=0):
        super().__init__()
        self.norm1 = layers.LayerNormalization()
        self.attn = layers.MultiHeadAttention(num_heads, dim // num_heads)
        self.norm2 = layers.LayerNormalization()
        self.mlp = keras.Sequential([
            layers.Dense(int(dim * mlp_ratio)),
            layers.Activation('gelu'),
            layers.Dense(dim)
        ])
        # Dropout has no weights, so one instance can be reused for both sub-layers;
        # every call draws a fresh mask.
        self.dropout = layers.Dropout(dropout)

    def call(self, x, training=None):
        # Normalize once and use the result as both query and value
        # (self-attention); the previous version ran norm1 twice on the same input.
        h = self.norm1(x)
        # Bug fix: dropout is applied to each sub-layer's output instead of to the
        # whole residual stream, so the identity path is never dropped.
        x = x + self.dropout(self.attn(h, h, training=training), training=training)
        x = x + self.dropout(self.mlp(self.norm2(x), training=training), training=training)
        return x

In [58]:
# image_size = 32 
# patch_size = 8  
# num_layers = 6 
# num_heads = 6   
# hidden_dim = 384
# mlp_dim = 1536   
# num_classes = 10 
# dropout = 0.08   
# attention_dropout = 0.08 

# ViT for 32x32 RGB inputs: 8x8 patches give a 4x4 grid (16 patch tokens plus
# one class token), 6 blocks of width 384 with 6 heads each and a 4x MLP.
vit_config = dict(
    input_shape=(32, 32, 3),
    patch_size=8,
    num_classes=10,
    embed_dim=384,
    depth=6,
    num_heads=6,
    mlp_ratio=4,
    dropout=0.1,
)
vit = VisionTransformer(**vit_config)

# vit.summary()

In [65]:
# ViT learning-rate schedule: the rate is multiplied by 0.9 once every
# `decay_steps` optimizer steps (staircase=True gives discrete drops rather
# than a smooth exponential curve).
initial_learning_rate, decay_steps = 1e-4, 5000
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps, 0.9, staircase=True
)


In [66]:
import tensorflow_addons as tfa

# tfa's AdamW does not scale `weight_decay` by the learning-rate schedule, and
# the TFA docs advise decaying weight_decay together with a decaying learning
# rate. Mirror `lr_schedule` (same steps, rate and staircase) so the
# weight-decay / learning-rate ratio stays constant during training.
wd_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4,
    decay_steps=lr_schedule.decay_steps,
    decay_rate=lr_schedule.decay_rate,
    staircase=lr_schedule.staircase,
)

vit.compile(
    optimizer=tfa.optimizers.AdamW(
        learning_rate=lr_schedule,
        weight_decay=wd_schedule,
    ),
    # The ViT head is a plain Dense layer with no softmax, so the loss consumes logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [67]:
# Save the model whenever validation accuracy improves.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    'vit_cifar10_best', monitor='val_accuracy', save_best_only=True
)
# Stop after 5 epochs without a val_accuracy improvement.
# NOTE(review): in some Keras versions restore_best_weights only takes effect
# when early stopping actually triggers; confirm for the installed version.
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy', patience=5, restore_best_weights=True
)
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir='./logs/vit_cifar10')

callbacks = [checkpoint_cb, early_stop_cb, tensorboard_cb]


In [68]:
# NOTE(review): test_ds doubles as the validation set here, so the val_accuracy
# that drives checkpointing and early stopping is not an unbiased estimate of
# test accuracy.
history = vit.fit(train_ds, epochs=10, validation_data=test_ds, callbacks=callbacks)

Epoch 1/10



INFO:tensorflow:Assets written to: vit_cifar10_best\assets


INFO:tensorflow:Assets written to: vit_cifar10_best\assets


Epoch 2/10



INFO:tensorflow:Assets written to: vit_cifar10_best\assets


INFO:tensorflow:Assets written to: vit_cifar10_best\assets


Epoch 3/10



INFO:tensorflow:Assets written to: vit_cifar10_best\assets


INFO:tensorflow:Assets written to: vit_cifar10_best\assets


Epoch 4/10
Epoch 5/10



INFO:tensorflow:Assets written to: vit_cifar10_best\assets


INFO:tensorflow:Assets written to: vit_cifar10_best\assets


Epoch 6/10



INFO:tensorflow:Assets written to: vit_cifar10_best\assets


INFO:tensorflow:Assets written to: vit_cifar10_best\assets


Epoch 7/10



INFO:tensorflow:Assets written to: vit_cifar10_best\assets


INFO:tensorflow:Assets written to: vit_cifar10_best\assets


Epoch 8/10
Epoch 9/10
Epoch 10/10



INFO:tensorflow:Assets written to: vit_cifar10_best\assets


INFO:tensorflow:Assets written to: vit_cifar10_best\assets




In [None]:
# `evaluate` returns the compiled loss followed by the compiled metrics
# (here a single 'accuracy').
test_loss, test_accuracy = vit.evaluate(x=test_ds)
print(f"Test accuracy: {test_accuracy:.4f}")

In [87]:
from models.resnet import ResNet18

In [88]:
class ResBlock(tf.keras.layers.Layer):
    """Basic two-convolution residual block.

    conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN, added to a shortcut, then ReLU.
    The shortcut is a 1x1 conv + BN projection whenever the stride is not 1 or
    the input channel count differs from ``filters``; otherwise it is the identity.

    Args:
        filters: output channels of both 3x3 convolutions.
        strides: stride of the first convolution and of the projection shortcut.
    """

    def __init__(self, filters, strides=1):
        super().__init__()
        self.filters = filters
        self.strides = strides
        self.conv1 = tf.keras.layers.Conv2D(filters, 3, strides=strides, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filters, 3, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        # A stride change always needs a projection; a channel-only change is
        # detected in `build`, once the input shape is known.
        self.downsample = self._make_projection() if strides != 1 else None

    def _make_projection(self):
        """1x1 conv + BN mapping the input onto the block's output shape."""
        return tf.keras.Sequential([
            tf.keras.layers.Conv2D(self.filters, 1, strides=self.strides),
            tf.keras.layers.BatchNormalization()
        ])

    def build(self, input_shape):
        # Bug fix: with strides == 1 but a different channel count, the identity
        # shortcut cannot be added to the residual branch, so project it too.
        in_channels = input_shape[-1]
        if self.downsample is None and in_channels is not None and in_channels != self.filters:
            self.downsample = self._make_projection()
        super().build(input_shape)

    def call(self, inputs):
        identity = inputs if self.downsample is None else self.downsample(inputs)

        x = self.conv1(inputs)
        x = self.bn1(x)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        return tf.nn.relu(x + identity)

class ResNet18_v2(tf.keras.Model):
    """ResNet-18 classifier built from ``ResBlock``s; returns logits.

    Layout: stem (conv -> BN -> ReLU -> optional max-pool). Then four stages of
    two ``ResBlock``s each, with 64/128/256/512 filters; stages 2-4 open with a
    stride-2 block. The network ends with global average pooling and a Dense
    head (no softmax).

    Args:
        num_classes: number of output logits.
        small_inputs: if True, use a 3x3 stride-1 stem convolution and no
            max-pool. The default (False) keeps the 7x7 stride-2 convolution plus
            stride-2 max-pool. That stem downsamples 4x before the first stage, or
            32x in total with stages 2-4, so a 32x32 input reaches the last stage
            as a 1x1 feature map.
    """

    def __init__(self, num_classes=10, small_inputs=False):
        super().__init__()

        stem_kernel, stem_stride = (3, 1) if small_inputs else (7, 2)
        self.conv1 = tf.keras.layers.Conv2D(64, stem_kernel, strides=stem_stride, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.pool1 = None if small_inputs else tf.keras.layers.MaxPool2D(3, strides=2, padding='same')

        # Use plain lists of ResBlock instances (tracked by Keras) rather than
        # Sequential. The attribute names are kept so existing checkpoints
        # still match.
        self.blocks1 = [ResBlock(64), ResBlock(64)]
        self.blocks2 = [ResBlock(128, strides=2), ResBlock(128)]
        self.blocks3 = [ResBlock(256, strides=2), ResBlock(256)]
        self.blocks4 = [ResBlock(512, strides=2), ResBlock(512)]

        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(num_classes)

    def call(self, inputs):
        x = tf.nn.relu(self.bn1(self.conv1(inputs)))
        if self.pool1 is not None:
            x = self.pool1(x)

        # Run the four stages in order.
        for block in (*self.blocks1, *self.blocks2, *self.blocks3, *self.blocks4):
            x = block(x)

        return self.fc(self.avgpool(x))

In [91]:
res = ResNet18_v2(num_classes=10)

initial_learning_rate = 5e-5
decay_steps = 5000

# Multiply the learning rate by 0.9 once every `decay_steps` optimizer steps.
lr_schedule_res = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=decay_steps,
    decay_rate=0.9,
    staircase=True
)

# tfa's AdamW does not scale `weight_decay` by the learning-rate schedule, and
# the TFA docs advise decaying weight_decay together with a decaying learning
# rate. Use the same staircase schedule so the weight-decay / learning-rate
# ratio stays constant during training.
wd_schedule_res = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4,
    decay_steps=decay_steps,
    decay_rate=0.9,
    staircase=True
)

res.compile(
    optimizer=tfa.optimizers.AdamW(
        learning_rate=lr_schedule_res,
        weight_decay=wd_schedule_res
    ),
    # ResNet18_v2 ends in a Dense layer with no softmax, so the loss consumes logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

callbacks_res = [
    tf.keras.callbacks.ModelCheckpoint(
        'res_cifar10_best',
        save_best_only=True,
        monitor='val_accuracy'
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=10,
        restore_best_weights=True
    ),
    tf.keras.callbacks.TensorBoard(log_dir='./logs/res_cifar10')
]

# NOTE(review): test_ds doubles as the validation set, so the val_accuracy used
# for checkpointing and early stopping is not an unbiased test estimate.
history_res = res.fit(
    train_ds,
    validation_data=test_ds,
    epochs=50,
    callbacks=callbacks_res
)

Epoch 1/50



INFO:tensorflow:Assets written to: res_cifar10_best\assets


INFO:tensorflow:Assets written to: res_cifar10_best\assets


Epoch 2/50



INFO:tensorflow:Assets written to: res_cifar10_best\assets


INFO:tensorflow:Assets written to: res_cifar10_best\assets


Epoch 3/50



INFO:tensorflow:Assets written to: res_cifar10_best\assets


INFO:tensorflow:Assets written to: res_cifar10_best\assets


Epoch 4/50



INFO:tensorflow:Assets written to: res_cifar10_best\assets


INFO:tensorflow:Assets written to: res_cifar10_best\assets


Epoch 5/50
Epoch 6/50



INFO:tensorflow:Assets written to: res_cifar10_best\assets


INFO:tensorflow:Assets written to: res_cifar10_best\assets


Epoch 7/50



INFO:tensorflow:Assets written to: res_cifar10_best\assets


INFO:tensorflow:Assets written to: res_cifar10_best\assets


Epoch 8/50



INFO:tensorflow:Assets written to: res_cifar10_best\assets


INFO:tensorflow:Assets written to: res_cifar10_best\assets


Epoch 9/50
Epoch 10/50



INFO:tensorflow:Assets written to: res_cifar10_best\assets


INFO:tensorflow:Assets written to: res_cifar10_best\assets


Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50



INFO:tensorflow:Assets written to: res_cifar10_best\assets


INFO:tensorflow:Assets written to: res_cifar10_best\assets


Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50


In [102]:
# Reset Keras' global session state before building further models.
# NOTE(review): `vit` and `res` are still bound to notebook globals, so this call
# alone presumably does not release their variables; `del` them if memory matters.
keras.backend.clear_session()