In [24]:
from keras.api.applications import EfficientNetB0
from keras import layers
import keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [25]:
BATCH_SIZE = 32
RESOLUTION = 224

realwaste = keras.utils.image_dataset_from_directory(
    "./data/realwaste",
    labels="inferred",
    label_mode="categorical",
    image_size=(RESOLUTION, RESOLUTION),
    batch_size=32
)

trashnet = keras.utils.image_dataset_from_directory(
    "./data/trashnet",
    labels="inferred",
    label_mode="categorical",
    image_size=(RESOLUTION, RESOLUTION),
    batch_size=32
)

Found 3587 files belonging to 6 classes.
Found 2527 files belonging to 6 classes.


In [26]:
def split(dataset, train_pct):
        size = len(list(dataset.as_numpy_iterator()))
        train = dataset.take(int(train_pct * size))
        validation = dataset.skip(int(train_pct * size))
        return train, validation

training_dataset, validation_dataset = split(realwaste, 0.8)

In [27]:
augmentation_layers = [
    # Lighting variations
    layers.RandomBrightness(factor=(-0.2, 0.2)),
    # Blurring
    layers.GaussianNoise(stddev=0.2),
    # Distortions
    layers.RandomRotation(factor=0.1, fill_mode='nearest'),
    layers.RandomFlip(mode='horizontal'),
    layers.RandomZoom(height_factor=(-0.2, 0.2), width_factor=(-0.2, 0.2)),
    layers.RandomTranslation(height_factor=0.1, width_factor=0.1, fill_mode='nearest'),
    # Color variations
    layers.RandomContrast(factor=(0.8, 1.2)),
]

def augment(image):
    for layer in augmentation_layers:
        image = layer(image)
    return image

In [28]:
NUM_CLASSES = 6

def preprocess_augment(image, label):
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    image = augment(image)
    return image, label

def resize(image, label):
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    return image, label

# Preprocess training
training_dataset = training_dataset.map(preprocess_augment, num_parallel_calls=tf.data.AUTOTUNE)
training_dataset = training_dataset.prefetch(tf.data.AUTOTUNE)

validation_dataset = validation_dataset.map(resize, num_parallel_calls=tf.data.AUTOTUNE)

trashnet = trashnet.map(resize, num_parallel_calls=tf.data.AUTOTUNE)
trashnet = trashnet.prefetch(tf.data.AUTOTUNE)

In [29]:
def build_feature_extractor(input_shape):
    base_model = EfficientNetB0(include_top=False, weights=None, input_shape=input_shape)
    x = base_model.output
    x = layers.GlobalAveragePooling2D()(x)
    return keras.Model(base_model.input, x)

# Label Classifier
def build_label_classifier(feature_extractor):
    x = feature_extractor.output
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(NUM_CLASSES, activation='softmax')(x)
    return keras.Model(feature_extractor.input, x)

# Domain Classifier
def build_domain_classifier(feature_extractor):
    x = feature_extractor.output
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(1, activation='sigmoid')(x)
    return keras.Model(feature_extractor.input, x)

In [30]:
def build_dann_model(feature_extractor, label_classifier, domain_classifier):
    input_image = layers.Input(shape=(RESOLUTION, RESOLUTION, 3))
    
    # Feature extraction
    features = feature_extractor(input_image)
    
    # Label prediction
    label_pred = label_classifier(input_image)
    
    # Domain prediction
    domain_pred = domain_classifier(input_image)
    
    return keras.Model(input_image, [label_pred, domain_pred])

In [None]:
label_loss_fn = tf.keras.losses.CategoricalCrossentropy()
domain_loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = keras.optimizers.Adam(learning_rate=0.0001)

# Define the metrics
label_accuracy = tf.keras.metrics.CategoricalAccuracy()
domain_accuracy = tf.keras.metrics.BinaryAccuracy()

@tf.function
def train_step(source_images, source_labels, target_images):
    with tf.GradientTape() as tape:
        source_label_preds, source_domain_preds = dann_model(source_images)
        source_label_loss = label_loss_fn(source_labels, source_label_preds)
        source_domain_loss = domain_loss_fn(tf.zeros_like(source_domain_preds), source_domain_preds)
        
        _, target_domain_preds = dann_model(target_images)
        target_domain_loss = domain_loss_fn(tf.ones_like(target_domain_preds), target_domain_preds)
        
        total_loss = source_label_loss + source_domain_loss + target_domain_loss
    
    gradients = tape.gradient(total_loss, dann_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, dann_model.trainable_variables))
    
    label_accuracy.update_state(source_labels, source_label_preds)
    domain_accuracy.update_state(tf.zeros_like(source_domain_preds), source_domain_preds)
    domain_accuracy.update_state(tf.ones_like(target_domain_preds), target_domain_preds)

# Training loop
def train_dann(training_dataset, trashnet, epochs):
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        
        label_accuracy.reset_state()
        domain_accuracy.reset_state()
        
        for source_images, source_labels in training_dataset:
            target_images, _ = next(iter(trashnet))
            train_step(source_images, source_labels, target_images)
        
        print(f"Label Accuracy: {label_accuracy.result().numpy()}, Domain Accuracy: {domain_accuracy.result().numpy()}")

In [None]:
feature_extractor = build_feature_extractor((RESOLUTION, RESOLUTION, 3))
label_classifier = build_label_classifier(feature_extractor)
domain_classifier = build_domain_classifier(feature_extractor)
dann_model = build_dann_model(feature_extractor, label_classifier, domain_classifier)

dann_model.compile(optimizer=optimizer)

train_dann(training_dataset, trashnet, epochs=10)

Epoch 1/10
Label Accuracy: 0.25, Domain Accuracy: 0.5001736283302307
Epoch 2/10
Label Accuracy: 0.24166665971279144, Domain Accuracy: 0.5055555701255798
Epoch 3/10
Label Accuracy: 0.24409721791744232, Domain Accuracy: 0.510937511920929
Epoch 4/10
Label Accuracy: 0.24930556118488312, Domain Accuracy: 0.6421874761581421
Epoch 5/10
Label Accuracy: 0.2777777910232544, Domain Accuracy: 0.770312488079071
Epoch 6/10
Label Accuracy: 0.3027777671813965, Domain Accuracy: 0.7741319537162781
Epoch 7/10
Label Accuracy: 0.2993055582046509, Domain Accuracy: 0.7769097089767456
Epoch 8/10
Label Accuracy: 0.3402777910232544, Domain Accuracy: 0.8092013597488403
Epoch 9/10
Label Accuracy: 0.3409722149372101, Domain Accuracy: 0.8192708492279053
Epoch 10/10
Label Accuracy: 0.3583333194255829, Domain Accuracy: 0.8496527671813965
