In [None]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import random
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
from typing import List, Dict, Any, Tuple

class Config:
    """Hyperparameters for the training script plus the TensorFlow device to use."""

    def __init__(self):
        # Data / optimisation settings.
        self.batch_size = 256
        self.learning_rate = 0.05
        self.weight_decay = 1e-4
        self.epochs = 20
        # Embedding and predictor hidden widths.
        self.feature_dim = 256
        self.pred_dim = 256
        # Coefficient passed to update_target_encoder() by train().
        self.momentum = 0.99
        # Use the first GPU when TensorFlow reports one, otherwise the CPU.
        has_gpu = bool(tf.config.list_physical_devices('GPU'))
        self.device = '/GPU:0' if has_gpu else '/CPU:0'

def augment(image, label, padding=2):
    """Randomly augment a single image ``(H, W, C)`` or a batch ``(B, H, W, C)``.

    Each image independently gets:
      * a random translation of up to ``padding`` pixels (zero-pad by
        ``padding`` on every spatial side, then randomly crop back to
        ``H x W``; ``padding=0`` makes the crop a no-op),
      * random horizontal and vertical flips,
      * random brightness (max delta 0.1) and contrast (factor in [0.9, 1.1]).

    NOTE(review): vertical/horizontal flips can turn some digits into other
    digits (e.g. 6 vs 9); confirm they are intended for this data.

    Args:
        image: rank-3 image or rank-4 image batch tensor.
        label: returned unchanged (may be ``None``).
        padding: zero-padding in pixels added before the random crop.

    Returns:
        ``(augmented_image, label)`` with the image in the input's shape.
    """

    def _augment_one(img):
        # Random translation: pad, then crop back to the original size.
        padded = tf.pad(img, [[padding, padding], [padding, padding], [0, 0]])
        img = tf.image.random_crop(padded, size=tf.shape(img))
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)
        img = tf.image.random_brightness(img, max_delta=0.1)
        img = tf.image.random_contrast(img, lower=0.9, upper=1.1)
        return img

    image = tf.convert_to_tensor(image)
    if image.shape.rank == 4:
        # random_crop needs a size of the same rank as its input, and the
        # brightness/contrast ops draw a single factor per call, so a batch is
        # processed image-by-image to get independent randomness per image.
        image = tf.map_fn(_augment_one, image)
    else:
        image = _augment_one(image)
    return image, label

class Encoder(keras.Model):
    """Convolutional backbone mapping an image batch to ``feature_dim`` embeddings.

    Layer order: 3x3 'same' ReLU convolutions with 64, 128 and 256 filters,
    each followed by batch norm, with 2x2 max-pooling after the 128- and
    256-filter blocks; then flatten, a 1024-unit ReLU dense layer + batch
    norm, and a linear dense projection to ``feature_dim`` + batch norm.
    """

    def __init__(self, feature_dim):
        super().__init__()
        conv_kwargs = {'kernel_size': 3, 'strides': 1, 'padding': 'same', 'activation': 'relu'}
        stack = [
            layers.Conv2D(64, **conv_kwargs),
            layers.BatchNormalization(),
            layers.Conv2D(128, **conv_kwargs),
            layers.BatchNormalization(),
            layers.MaxPooling2D(pool_size=2),
            layers.Conv2D(256, **conv_kwargs),
            layers.BatchNormalization(),
            layers.MaxPooling2D(pool_size=2),
            layers.Flatten(),
            layers.Dense(1024, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(feature_dim),
            layers.BatchNormalization(),
        ]
        self.encoder = keras.Sequential(stack)

    def call(self, x):
        """Embed ``x`` with the sequential stack."""
        return self.encoder(x)

class Predictor(keras.Model):
    """Two-layer MLP head mapping an embedding to a ``feature_dim`` prediction.

    Hidden layer: ``pred_dim`` ReLU units followed by batch norm; the output
    layer is linear.
    """

    def __init__(self, feature_dim, pred_dim):
        super().__init__()
        hidden = layers.Dense(pred_dim, activation='relu')
        hidden_norm = layers.BatchNormalization()
        projection = layers.Dense(feature_dim)
        self.predictor = keras.Sequential([hidden, hidden_norm, projection])

    def call(self, x):
        """Apply the MLP head to ``x``."""
        return self.predictor(x)

class DirectPred(keras.Model):
    """Online encoder + predictor trained against a gradient-free target encoder.

    The target encoder has the same architecture as the online encoder and
    starts as an exact copy of it; its outputs are wrapped in
    ``tf.stop_gradient`` so no gradient reaches it through the loss.

    NOTE(review): despite the name, the predictor here is an MLP that is
    trained by the optimizer (BYOL-style); nothing in this class sets its
    weights directly from feature statistics.

    Args:
        feature_dim: embedding size produced by the encoders.
        pred_dim: hidden width of the predictor MLP.
        image_shape: per-example input shape used to build the sub-models.
    """

    def __init__(self, feature_dim, pred_dim, image_shape=(28, 28, 1)):
        super().__init__()
        self.encoder = Encoder(feature_dim)
        self.predictor = Predictor(feature_dim, pred_dim)
        self.target_encoder = Encoder(feature_dim)

        # Build every sub-model eagerly with a concrete dummy batch so their
        # weights exist before they are copied below.
        dummy_input = tf.zeros((1, *image_shape))
        self.predictor(self.encoder(dummy_input))
        self.target_encoder(dummy_input)

        # Initialize target_encoder with encoder's parameters.
        self.target_encoder.set_weights(self.encoder.get_weights())

    def call(self, x1, x2, training=None):
        """Return ``(p1, p2, t1, t2)``: online predictions and target embeddings per view."""
        z1 = self.encoder(x1, training=training)
        z2 = self.encoder(x2, training=training)
        p1 = self.predictor(z1, training=training)
        p2 = self.predictor(z2, training=training)
        # tf.stop_gradient is an op applied to tensors, not a context manager:
        # block gradients on the target outputs themselves.
        t1 = tf.stop_gradient(self.target_encoder(x1, training=training))
        t2 = tf.stop_gradient(self.target_encoder(x2, training=training))
        return p1, p2, t1, t2

@tf.function
def directpred_loss(p1, p2, t1, t2):
    """Symmetric cross-view mean-squared error.

    Each view's prediction is compared with the *other* view's target
    (``p1`` vs ``t2`` and ``p2`` vs ``t1``); the two mean squared errors are
    summed. No normalization is applied to predictions or targets.
    """
    cross_12 = tf.reduce_mean(tf.square(p1 - t2))
    cross_21 = tf.reduce_mean(tf.square(p2 - t1))
    return cross_12 + cross_21

@tf.function
def train_step(model, optimizer, x1, x2):
    """Take one optimizer step on the cross-view loss for a pair of views.

    Only the online encoder and predictor variables are differentiated and
    updated. The target encoder is excluded because its outputs are not
    meant to receive gradients; train() updates it separately through
    update_target_encoder().

    The forward pass is run with ``training=True`` so BatchNormalization
    layers use batch statistics and update their moving averages. When
    ``training`` is not given (and no outer call sets it), Keras runs them in
    inference mode instead.

    Args:
        model: DirectPred-style model exposing ``encoder`` and ``predictor``
            and returning ``(p1, p2, t1, t2)`` when called on two views.
        optimizer: Keras optimizer used to apply the gradients.
        x1, x2: the two augmented views of the same batch.

    Returns:
        The scalar loss for this step.
    """
    online_variables = (
        model.encoder.trainable_variables + model.predictor.trainable_variables
    )
    with tf.GradientTape() as tape:
        p1, p2, t1, t2 = model(x1, x2, training=True)
        loss = directpred_loss(p1, p2, t1, t2)
    gradients = tape.gradient(loss, online_variables)
    optimizer.apply_gradients(zip(gradients, online_variables))
    return loss

def update_target_encoder(model, momentum):
    """Blend each target-encoder weight toward its online-encoder counterpart.

    In place, via ``assign``:
        target <- momentum * target + (1 - momentum) * online

    Weights are paired positionally (``zip`` stops at the shorter list), so
    both encoders are expected to share the same architecture.
    """
    keep = momentum
    blend = 1 - momentum
    online_weights = model.encoder.weights
    target_weights = model.target_encoder.weights
    for target_w, online_w in zip(target_weights, online_weights):
        target_w.assign(keep * target_w + blend * online_w)

def train(model, train_dataset, optimizer, config, test_dataset=None, eval_every=10):
    """Train ``model`` self-supervised for ``config.epochs`` epochs.

    For every batch, two independently augmented views are built. One
    optimizer step is taken via ``train_step``, and the target encoder is
    then EMA-updated with ``config.momentum``. The mean loss is printed after
    each epoch.

    Args:
        model: the DirectPred model being trained.
        train_dataset: iterable of ``(images, labels)`` batches; labels are
            ignored during training.
        optimizer: Keras optimizer passed to ``train_step``.
        config: object providing ``epochs`` and ``momentum``.
        test_dataset: ``(images, labels)`` batches for periodic evaluation.
            When omitted, a module-level ``test_dataset`` is used if one
            exists (the script entry point defines one). Otherwise periodic
            evaluation is skipped.
        eval_every: run ``evaluate`` every this many epochs (0 disables it).

    Raises:
        ValueError: if ``train_dataset`` yields no batches in an epoch.
    """
    if test_dataset is None:
        # Backward compatibility: the script entry point defines
        # `test_dataset` at module level and calls train() without it.
        test_dataset = globals().get("test_dataset")

    @tf.function
    def make_views(batch):
        # augment() may only accept a single (H, W, C) image, so apply it
        # image-by-image; each call draws fresh randomness.
        def one_view(image):
            return augment(image, None)[0]

        return tf.map_fn(one_view, batch), tf.map_fn(one_view, batch)

    for epoch in range(config.epochs):
        total_loss = 0.0
        num_batches = 0
        for x, _ in tqdm(train_dataset, desc=f"Epoch {epoch + 1}/{config.epochs}"):
            x1, x2 = make_views(x)
            loss = train_step(model, optimizer, x1, x2)
            update_target_encoder(model, config.momentum)
            total_loss += float(loss)
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_dataset yielded no batches")
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

        if test_dataset is not None and eval_every and (epoch + 1) % eval_every == 0:
            linear_acc, knn_acc = evaluate(model, train_dataset, test_dataset)
            print(f"Linear Evaluation Accuracy: {linear_acc:.4f}")
            print(f"KNN Evaluation Accuracy: {knn_acc:.4f}")

def extract_features(model, dataset):
    """Embed every batch of ``dataset`` with ``model.encoder``.

    Args:
        model: object exposing an ``encoder`` callable (a DirectPred here).
        dataset: iterable of ``(x, y)`` batches whose encoder outputs and
            labels support ``.numpy()``.

    Returns:
        ``(features, labels)`` NumPy arrays, concatenated along the first
        axis in iteration order.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    features = []
    labels = []
    for x, y in dataset:
        # Explicit inference mode: the encoder contains BatchNormalization
        # layers, which must use their moving statistics during evaluation.
        features.append(model.encoder(x, training=False).numpy())
        labels.append(y.numpy())
    if not features:
        raise ValueError("dataset yielded no batches; cannot extract features")
    return np.concatenate(features), np.concatenate(labels)

def linear_evaluation(train_features, train_labels, test_features, test_labels):
    """Score frozen features with a linear softmax probe.

    A single ``Dense(10, softmax)`` layer (hard-coded to 10 classes) is fit
    with Adam on sparse categorical cross-entropy for 100 silent epochs, then
    evaluated on the test features.

    Returns:
        The test accuracy reported by Keras ``evaluate``.
    """
    probe = keras.Sequential()
    probe.add(layers.Dense(10, activation='softmax'))
    probe.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    probe.fit(train_features, train_labels, epochs=100, verbose=0)
    _test_loss, test_accuracy = probe.evaluate(test_features, test_labels, verbose=0)
    return test_accuracy

def knn_evaluation(train_features, train_labels, test_features, test_labels, k=5):
    """Return the test accuracy of a ``k``-nearest-neighbour classifier.

    The classifier is fit on the train features/labels and scored on the
    test split.
    """
    classifier = KNeighborsClassifier(n_neighbors=k).fit(train_features, train_labels)
    # ClassifierMixin.score computes accuracy_score(y_true, self.predict(X)).
    return classifier.score(test_features, test_labels)

def evaluate(model, train_dataset, test_dataset):
    """Run the linear-probe and k-NN evaluations on encoder features.

    Features for both splits come from ``extract_features``. Both evaluators
    are fit on the train split and scored on the test split.

    Returns:
        ``(linear_acc, knn_acc)``.
    """
    splits = [extract_features(model, ds) for ds in (train_dataset, test_dataset)]
    (train_x, train_y), (test_x, test_y) = splits
    scores = (
        linear_evaluation(train_x, train_y, test_x, test_y),
        knn_evaluation(train_x, train_y, test_x, test_y),
    )
    return scores

if __name__ == "__main__":
    config = Config()

    # Load MNIST, add a channel axis and scale pixels to [0, 1].
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255
    x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = train_dataset.shuffle(10000).batch(config.batch_size).prefetch(tf.data.AUTOTUNE)

    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_dataset = test_dataset.batch(config.batch_size).prefetch(tf.data.AUTOTUNE)

    # Batches per epoch via ceil division (batch() keeps the final partial
    # batch), instead of iterating the whole dataset with len(list(...)).
    steps_per_epoch = (len(x_train) + config.batch_size - 1) // config.batch_size

    with tf.device(config.device):
        model = DirectPred(config.feature_dim, config.pred_dim)

        # Ensure the model is built, using a concrete zero batch.
        dummy_input = tf.zeros((1, 28, 28, 1))
        model(dummy_input, dummy_input)

        # Pass the schedule to the optimizer constructor rather than
        # reassigning `learning_rate` afterwards; also apply the configured
        # weight decay (requires the TF >= 2.11 / Keras 3 optimizers).
        lr_schedule = keras.optimizers.schedules.CosineDecay(
            initial_learning_rate=config.learning_rate,
            decay_steps=config.epochs * steps_per_epoch,
        )
        optimizer = keras.optimizers.SGD(
            learning_rate=lr_schedule,
            momentum=0.9,
            weight_decay=config.weight_decay,
        )

        print("Starting DirectPred training...")
        train(model, train_dataset, optimizer, config)

        print("\nFinal Evaluation:")
        linear_acc, knn_acc = evaluate(model, train_dataset, test_dataset)
        print(f"Linear Evaluation Accuracy: {linear_acc:.4f}")
        print(f"KNN Evaluation Accuracy: {knn_acc:.4f}")