In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, "../")

from autogluon.vision import ImagePredictor, ImageDataset
import numpy as np
import pandas as pd

# Disable pandas' display truncation (rows, columns, cell width) for this
# notebook session so previews are rendered without elision.
pd.options.display.max_rows = None
pd.options.display.max_columns = None
pd.options.display.max_colwidth = None

## Read data

In [None]:
# path to data
CIFAR_10_DATA_PATH = "/datasets/uly/ood-data/cifar10_png/"
CIFAR_100_DATA_PATH = "/datasets/uly/ood-data/cifar100_png/"
MNIST_DATA_PATH = "/datasets/uly/ood-data/mnist_png/"
FASHION_MNIST_DATA_PATH = "/datasets/uly/ood-data/fashion_mnist_png/"


def _load_train_test(root):
    """Load the image folders under ``root`` and return ``(train, test)``.

    The result of ``ImageDataset.from_folders`` is unpacked into three values;
    the middle one is not used anywhere in this notebook, so it is dropped
    here instead of being bound to the notebook-level ``_`` name (which
    IPython also uses for the last cell output).
    """
    train_split, _unused_split, test_split = ImageDataset.from_folders(root=root)
    return train_split, test_split


# read data from root folder
cifar_10_train_dataset, cifar_10_test_dataset = _load_train_test(CIFAR_10_DATA_PATH)
cifar_100_train_dataset, cifar_100_test_dataset = _load_train_test(CIFAR_100_DATA_PATH)
mnist_train_dataset, mnist_test_dataset = _load_train_test(MNIST_DATA_PATH)
fashion_mnist_train_dataset, fashion_mnist_test_dataset = _load_train_test(FASHION_MNIST_DATA_PATH)

In [None]:
# Registry of the datasets used below: dataset name -> its train and test splits.

data_model_dict = {
    name: {"train_data": train_split, "test_data": test_split}
    for name, train_split, test_split in (
        ("cifar-10", cifar_10_train_dataset, cifar_10_test_dataset),
        ("cifar-100", cifar_100_train_dataset, cifar_100_test_dataset),
        ("mnist", mnist_train_dataset, mnist_test_dataset),
        ("fashion-mnist", fashion_mnist_train_dataset, fashion_mnist_test_dataset),
    )
}

In [None]:
# Build a class-imbalanced training set by subsampling each label with its own fraction
def get_imbalanced_dataset(dataset, fractions, random_state=None):
    """Subsample each class of ``dataset`` by its own fraction.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Frame with a ``'label'`` column identifying the class of each row.
    fractions : sequence of float
        One sampling fraction per distinct label. Fractions are matched to
        the labels in sorted order, so for labels ``0..n-1`` entry ``i``
        applies to label ``i``.
    random_state : optional
        Forwarded to ``DataFrame.sample`` for every class. The default
        ``None`` keeps the sampling unseeded.

    Returns
    -------
    pandas.DataFrame
        The sampled rows of all classes, concatenated in label order with a
        fresh ``0..m-1`` index.

    Raises
    ------
    AssertionError
        If ``fractions`` does not contain exactly one entry per distinct label.
    """
    # Pair fractions with the labels actually present instead of assuming the
    # labels are exactly 0..n-1; a dataset labelled otherwise would silently
    # produce empty groups.
    labels = sorted(dataset['label'].dropna().unique())
    assert len(fractions) == len(labels), (
        f'expected {len(labels)} fractions (one per label), got {len(fractions)}'
    )

    # Collect the per-class samples and concatenate once. Growing a frame with
    # concat inside the loop is quadratic, and seeding it with an empty
    # object-dtype frame can change the column dtypes of the result.
    sampled_parts = []
    for label, fraction in zip(labels, fractions):
        idf = dataset[dataset['label'] == label].sample(frac=fraction, random_state=random_state)
        print(f'label {label} will have {idf.shape[0]} examples')
        sampled_parts.append(idf)

    if sampled_parts:
        imbalanced_dataset = pd.concat(sampled_parts, ignore_index=True)
    else:
        # pd.concat rejects an empty list; keep returning an empty frame with
        # the input's columns when there are no labels at all.
        imbalanced_dataset = pd.DataFrame(columns=dataset.columns)
    print(f'total imbalanced dataset length {imbalanced_dataset.shape[0]}')
    return imbalanced_dataset

### Uncomment below to create imbalanced datasets, then re-run the cell that builds `data_model_dict` so training uses them

# cifar_100_num_classes = len(cifar_100_train_dataset['label'].unique())
# cifar_100_distribution = [0.15] * int(cifar_100_num_classes * 0.9) + [1.] * int(cifar_100_num_classes * 0.1)
# cifar_100_train_dataset = get_imbalanced_dataset(cifar_100_train_dataset, cifar_100_distribution)
# cifar_10_train_dataset = get_imbalanced_dataset(cifar_10_train_dataset,[0.09,0.09,0.09,0.09,1.,1.,0.09,0.09,1.,1.])
# mnist_train_dataset = get_imbalanced_dataset(mnist_train_dataset,[0.09,0.09,0.09,0.09,1.,1.,0.09,0.09,1.,1.])
# fashion_mnist_train_dataset = get_imbalanced_dataset(fashion_mnist_train_dataset,[0.09,0.09,0.09,0.09,1.,1.,0.09,0.09,1.,1.])

In [None]:
# Sanity check: preview the first few rows of the MNIST training split
mnist_train_dataset.head()

## Train model

In [None]:
%%time

def train_ag_model(
    train_data,
    dataset_name,
    model_folder="./models/",
    epochs=100,
    model="swin_base_patch4_window7_224",
    time_limit=10*3600
):
    """Fit an AutoGluon ``ImagePredictor`` on ``train_data`` and save it to disk.

    Parameters
    ----------
    train_data
        Training set, passed unchanged to ``ImagePredictor.fit``.
    dataset_name : str
        Dataset identifier; used only to build the output file name.
    model_folder : str, default "./models/"
        Folder that receives the saved predictor. It is created if missing,
        and may be given with or without a trailing separator.
    epochs : int, default 100
        Value of the ``"epochs"`` hyperparameter.
    model : str, default "swin_base_patch4_window7_224"
        Value of the ``"model"`` hyperparameter; also part of the file name.
    time_limit : int, default 10*3600
        Forwarded to ``ImagePredictor.fit`` (presumably seconds, i.e. 10 hours
        by default; confirm against the AutoGluon docs).

    Returns
    -------
    ImagePredictor
        The fitted predictor, already saved to
        ``<model_folder>/<model>_<dataset_name>.ag``.
    """
    import os  # local import: keeps this cell independent of the notebook's import cell

    # Resolve the output path and make sure its folder exists before the
    # (potentially hours-long) fit, so a missing folder cannot surface only
    # at save time and cost the whole training run.
    filename = os.path.join(model_folder, f"{model}_{dataset_name}.ag")
    if model_folder:
        os.makedirs(model_folder, exist_ok=True)

    # init model
    predictor = ImagePredictor(verbosity=0)

    MODEL_PARAMS = {
        "model": model,
        "epochs": epochs,
    }

    # run training
    predictor.fit(
        train_data=train_data,
        # tuning_data=,
        ngpus_per_trial=1,
        hyperparameters=MODEL_PARAMS,
        time_limit=time_limit,
        random_state=123,
    )

    # save model
    predictor.save(filename)

    return predictor

## Train model for all datasets

In [None]:
# Backbone shared by every training run below.
model = "swin_base_patch4_window7_224"

# Fit one predictor per registered dataset. train_ag_model saves each
# predictor itself, so the returned object is discarded.
for key in data_model_dict:
    data = data_model_dict[key]
    dataset = key
    train_dataset = data["train_data"]

    # Short summary of what is about to be trained on.
    print(f"Dataset: {dataset}")
    print(f"  Records: {train_dataset.shape}")
    print(f"  Classes: {train_dataset.label.nunique()}")

    _ = train_ag_model(
        train_data=train_dataset,
        dataset_name=dataset,
        epochs=100,
        model=model,
    )