In [1]:
# Root directory of the dataset on this machine. The loader helpers defined
# later take `data_dir` as a parameter (defaulting to './'), so this value
# has to be passed to them explicitly.
data_dir = '/workspace'
import keras
from keras import layers
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import os
import random
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [None]:
def load_training_labels(data_dir='./', num_samples=None, balanced=False, random_state=None):
    """Load ``train_labels.csv`` and return it shuffled, optionally subsampled.

    Args:
        data_dir: Directory containing ``train_labels.csv``.
        num_samples: Number of rows to draw without replacement. When ``None``,
            every row is returned (shuffled) and *balanced* is ignored.
        balanced: When sampling, draw half of the rows from each ``label``
            class. For an odd *num_samples* the extra row comes from the
            ``False`` class, so exactly *num_samples* rows are returned.
        random_state: Seed or random state forwarded to ``DataFrame.sample``
            for reproducible draws. ``None`` keeps the unseeded behaviour.

    Returns:
        A DataFrame with a fresh ``RangeIndex`` whose ``label`` column has been
        cast to ``bool``.

    Raises:
        ValueError: Raised by ``DataFrame.sample`` when more rows are requested
            than the frame (or one class subset) contains.
    """
    training_labels = pd.read_csv(os.path.join(data_dir, 'train_labels.csv'))
    training_labels['label'] = training_labels['label'].astype('bool')

    if num_samples is None:
        return training_labels.sample(frac=1, random_state=random_state).reset_index(drop=True)

    if balanced:
        # Split the request so an odd num_samples still yields exactly
        # num_samples rows (taking num_samples // 2 from each class dropped one).
        n_pos = num_samples // 2
        n_neg = num_samples - n_pos
        is_pos = training_labels['label']
        pos = training_labels[is_pos].sample(n_pos, random_state=random_state)
        neg = training_labels[~is_pos].sample(n_neg, random_state=random_state)
        # Shuffle again so the two classes are interleaved rather than stacked.
        return pd.concat([pos, neg]).sample(frac=1, random_state=random_state).reset_index(drop=True)

    return training_labels.sample(num_samples, random_state=random_state).reset_index(drop=True)


def get_training_images(training_labels, data_dir='./'):
    """Load the training tiles listed in *training_labels* into one NumPy array.

    Each image is read from ``<data_dir>/train/<id>.tif`` in the row order of
    ``training_labels['id']``, so the result lines up with the labels frame.

    Returns:
        ``np.ndarray`` built from the per-image arrays produced by
        ``keras.utils.img_to_array``. The array is empty (shape ``(0,)``) when
        there are no ids.
    """
    train_dir = os.path.join(data_dir, 'train')
    # Use `image_id` rather than `id` so the builtin is not shadowed.
    return np.array([
        keras.utils.img_to_array(keras.utils.load_img(os.path.join(train_dir, f'{image_id}.tif')))
        for image_id in training_labels['id']
    ])


def get_test_images(data_dir='./'):
    """Load every ``.tif`` file in ``<data_dir>/test`` and return it with its id.

    Returns:
        A ``(test_images, test_ids)`` tuple. ``test_images`` is an ``np.ndarray``
        of the ``keras.utils.img_to_array`` outputs. ``test_ids`` holds the
        matching file stems, in the same order.
    """
    test_dir = os.path.join(data_dir, "test")
    # os.listdir returns entries in arbitrary order; sort them so the image
    # order (and therefore submission row order) is reproducible.
    test_image_files = sorted(f for f in os.listdir(test_dir) if f.endswith(".tif"))
    test_ids = [Path(f).stem for f in test_image_files]
    test_images = np.array(
        [keras.utils.img_to_array(keras.utils.load_img(os.path.join(test_dir, f)))
         for f in test_image_files])
    return test_images, test_ids


def plot_training_history(history):
    """Show training vs. validation curves for loss and accuracy side by side.

    *history* is the object returned by ``model.fit``. Its ``history`` dict
    must contain 'loss', 'val_loss', 'accuracy' and 'val_accuracy'.
    """
    curves = history.history
    # One entry per subplot: train key, validation key, their legend labels, title.
    panels = (
        ('loss', 'val_loss', 'Training Loss', 'Validation Loss', 'Loss'),
        ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy', 'Accuracy'),
    )
    plt.figure(figsize=(12, 4))
    for panel_index, (train_key, val_key, train_label, val_label, heading) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_index)
        plt.plot(curves[train_key], label=train_label)
        plt.plot(curves[val_key], label=val_label)
        plt.title(heading)
        plt.legend()
    plt.show()


In [7]:
def train_model(model, model_name, X, y, epochs=100, batch_size=32, validation_split=0.2,
                monitor='val_AUC', patience=10):
    """Fit *model*, restore its best checkpoint, and plot the training history.

    Args:
        model: A compiled Keras model. *monitor* must name a metric it reports
            on the validation split. The default 'val_AUC' presumably relies on
            a metric named 'AUC' being passed to ``compile`` (confirm against
            the caller).
        model_name: Prefix of the checkpoint file
            ``<model_name>.best_weights.keras`` in the current working directory.
        X, y: Training inputs and targets.
        epochs, batch_size, validation_split: Forwarded to ``model.fit``.
        monitor: Metric that both ModelCheckpoint and EarlyStopping maximise
            (``mode='max'``).
        patience: Number of epochs without improvement in *monitor* before
            EarlyStopping halts training.

    Returns:
        The same *model* object, with the best checkpointed weights loaded.
    """
    checkpoint_path = f'{model_name}.best_weights.keras'
    # Delete any checkpoint left by an earlier run. If this run never saves one
    # (e.g. `monitor` is not reported), load_weights below then fails loudly
    # instead of silently restoring stale weights.
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)

    checkpointer = keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        save_best_only=True,
        verbose=1,
        monitor=monitor,
        mode='max')
    early_stopping = keras.callbacks.EarlyStopping(
        patience=patience,
        verbose=1,
        monitor=monitor,
        mode='max')

    history = model.fit(X, y, epochs=epochs, batch_size=batch_size,
                        validation_split=validation_split,
                        callbacks=[checkpointer, early_stopping])

    # The last epoch run is not necessarily the best one; restore the checkpoint.
    model.load_weights(checkpoint_path)
    plot_training_history(history)
    return model


def evaluate_model_and_print_results(model, X_test, y_test):
    """Evaluate *model* on the held-out set and print loss, accuracy and AUC.

    ``model.evaluate`` must return exactly three values (loss followed by two
    metrics, presumably accuracy then AUC from ``compile``), otherwise the
    unpacking raises ``ValueError``. Returns ``None``.
    """
    loss, accuracy, auc = model.evaluate(X_test, y_test)
    for caption, value in (('Test Loss', loss), ('Test Accuracy', accuracy), ('Test AUC', auc)):
        print(f'{caption}: {value}')


def generate_submission(model, test_images, test_ids, model_name):
    """Predict on *test_images* and write ``submission_<model_name>.csv``.

    The CSV has an ``id`` column (from *test_ids*) and a ``label`` column (the
    flattened model predictions), with no index column. It is written to the
    current working directory, and a confirmation line is printed.
    """
    scores = model.predict(test_images).flatten()
    output_file = f'submission_{model_name}.csv'
    frame = pd.DataFrame({"id": test_ids, "label": scores})
    frame.to_csv(output_file, index=False)
    print(f"Submission saved to {output_file}")