In [1]:
# Instalar dependências (execute apenas uma vez)
!pip install tensorflow tensorflow-datasets matplotlib

# Importar bibliotecas
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Função para verificar se a imagem é válida
def is_valid(image, label):
    """Return a scalar bool tensor that is True when every value of `image` is finite.

    `label` is accepted and ignored so the function can be passed directly to
    `Dataset.filter` over (image, label) pairs.

    NOTE(review): an integer-typed image cast to float32 is always finite, so
    for integer images this check never rejects anything — confirm the
    dataset's image dtype if this filter is meant to drop corrupted samples.
    """
    as_float = tf.cast(image, tf.float32)
    finite_mask = tf.math.is_finite(as_float)
    return tf.reduce_all(finite_mask)

# Carregar e preparar o dataset
try:
    dataset, info = tfds.load('cats_vs_dogs', with_info=True, as_supervised=True)
except Exception as e:
    print("Erro ao carregar o dataset:", e)
    print("Tente rodar novamente ou atualizar o tensorflow-datasets.")
    # Re-raise so the cell stops here with the real cause, instead of failing
    # below with a confusing NameError because `dataset` was never assigned.
    raise

train_dataset = dataset['train']
if 'test' in dataset:
    test_dataset = dataset['test']
else:
    # No dedicated test split: hold out the first 1000 examples AND remove them
    # from training, so validation/test accuracy is measured on unseen images
    # (previously the same 1000 images were used for both training and testing).
    test_dataset = train_dataset.take(1000)
    train_dataset = train_dataset.skip(1000)

IMG_SIZE = 224  # square side length images are resized to

def format_image(image, label):
    """Resize `image` to IMG_SIZE x IMG_SIZE and scale pixels to [-1, 1].

    MobileNetV2's ImageNet weights expect inputs in [-1, 1] (the range produced
    by `tf.keras.applications.mobilenet_v2.preprocess_input`); dividing by 255
    alone would feed the frozen backbone values it was never trained on.

    Args:
        image: image tensor; pixel values presumably in [0, 255] — assumption,
            confirm against the dataset.
        label: class label, returned unchanged.

    Returns:
        Tuple (image, label), where image is the resized float32 tensor.
    """
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))  # default bilinear -> float32
    image = image / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    return image, label

# Filtrar imagens corrompidas e preparar batches
# Preprocess images in parallel and prefetch batches so the input pipeline
# prepares the next batch while the model is training on the current one.
train_dataset = (
    train_dataset
    .filter(is_valid)
    .map(format_image, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    test_dataset
    .filter(is_valid)
    .map(format_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# Carregar modelo pré-treinado MobileNetV2
# MobileNetV2 backbone with ImageNet weights and without its original top
# classifier; it is frozen so only the new head below is trained.
base_model = MobileNetV2(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,
    weights='imagenet',
)
base_model.trainable = False

# New classification head: global average pooling -> 128 ReLU units ->
# 2-way softmax (2 classes: cat and dog).
pooled = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(pooled)
predictions = Dense(2, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

# Integer labels with a softmax output -> sparse categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()

# Treinar o modelo
# Number of training passes; increase for higher accuracy at the cost of time.
EPOCHS = 2
history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=EPOCHS,
)

# Final loss/accuracy on test_dataset, printed as a percentage.
loss, accuracy = model.evaluate(test_dataset)
print(f"Acurácia: {accuracy * 100:.2f}%")

# Visualizar curvas de treinamento
# Plot per-epoch training and validation accuracy recorded by model.fit.
acc_curves = history.history
for metric_key, curve_label in (('accuracy', 'Treinamento'), ('val_accuracy', 'Validação')):
    plt.plot(acc_curves[metric_key], label=curve_label)
plt.xlabel('Épocas')
plt.ylabel('Acurácia')
plt.legend()
plt.show()

# Salvar modelo treinado (opcional)
# model.save("meu_modelo_transfer_learning.h5")

^C


KeyboardInterrupt: 