<a href="https://colab.research.google.com/github/jonl3379/jonl3379/blob/main/20231204_cnn_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TRAIN MODELS

In [None]:
### GLOBALS ###

# Boolean switches; presumably consumed by later cells (not used in this section).
TRAIN = True
ENSEMBLE = True
# Date stamp that becomes the first component of PREFIX.
DATE = "20231204"
# Dataset name; choose from {"oxford_flowers102", "cifar100", "stanford_dogs"}.
DATASET = "oxford_flowers102"
# Ensemble preset; choose from {"6xresnet", "6xefficientnet"}.
MODELS = "6xresnet"
# Presumably the number of training epochs; choose from {10, 20}.
EPOCHS = 10
# PREFIX layout: <date>_<dataset without underscores>_<models>_<epochs>ep_
PREFIX = f"{DATE}_{DATASET.replace('_', '')}_{MODELS}_{EPOCHS}ep_"
print("PREFIX =", PREFIX)

### Libraries ###

import numpy as np
import pickle
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow_datasets as tfds

### File locations ###

def fileloc():
  """Return the directory prefix under which result files are stored.

  Tries to import ``google.colab`` and mount Google Drive at
  ``/content/drive``. If that succeeds the Drive results folder is
  returned; if anything goes wrong the local ``./results/`` folder is
  used instead. The chosen location is printed before returning.

  Returns:
    str: the directory prefix, ending with a slash.
  """
  try:
    from google.colab import drive
    drive.mount("/content/drive")
    location = "/content/drive/MyDrive/edu/results/"
  except Exception:
    # Best-effort fallback: covers both "not running on Colab"
    # (ImportError) and a failed mount. Unlike a bare ``except:``, this
    # lets KeyboardInterrupt / SystemExit propagate.
    location = "./results/"
  # Print the local value: the global FILELOC does not exist yet the
  # first time this function runs, so printing it raised NameError.
  print("FILELOC =", location)
  return location

FILELOC = fileloc()

### Save and load ###

def fixname(name):
  """Return *name* with a ``.pkl`` extension, appending it only if missing."""
  extension = ".pkl"
  return name if name.endswith(extension) else name + extension

def save(data, filename, with_prefix = True):
  """Pickle *data* to a file under ``FILELOC``.

  Args:
    data: any picklable object.
    filename: target file name; ``.pkl`` is appended if missing
      (see ``fixname``).
    with_prefix: when true, the module-level ``PREFIX`` is inserted
      between ``FILELOC`` and the file name.

  Raises:
    OSError: if the target directory cannot be created or the file
      cannot be written.
  """
  import os  # local import: the file's import section does not provide os

  stem = PREFIX if with_prefix else ""
  full_filename = FILELOC + stem + fixname(filename)
  # The local fallback folder ("./results/") is never created elsewhere,
  # so make sure the target directory exists before opening the file.
  target_dir = os.path.dirname(full_filename)
  if target_dir:
    os.makedirs(target_dir, exist_ok=True)
  with open(full_filename, 'wb') as outfile:
    pickle.dump(data, outfile)

def load(filename, with_prefix = True):
  """Unpickle and return the object stored in a file under ``FILELOC``.

  Args:
    filename: file name; ``.pkl`` is appended if missing (see ``fixname``).
    with_prefix: when true, the module-level ``PREFIX`` is inserted
      between ``FILELOC`` and the file name.

  NOTE(review): a later ``def load(dataset)`` in this file rebinds the
  name ``load``, so this pickle loader is only reachable before that
  definition runs.
  NOTE(review): ``pickle.load`` executes arbitrary code from the file;
  only use it on files this project wrote itself.
  """
  stem = PREFIX if with_prefix else ""
  path = FILELOC + stem + fixname(filename)
  with open(path, 'rb') as infile:
    restored = pickle.load(infile)
  return restored

### Load dataset and normalize ###

def load_standard(dataset):
  """Load a TFDS dataset, carving a validation set out of its train split.

  The first 80% of ``train`` is used for training, the remaining 20% for
  validation, and the ``test`` split for testing.

  Returns:
    (train, validation, test, label_count), where label_count is the
    number of classes reported by the dataset's ``label`` feature.
  """
  def fetch(split, **extra):
    # All three loads share the dataset name and the supervised format.
    return tfds.load(dataset, as_supervised=True, split=split, **extra)

  train_part = fetch("train[:80%]")
  validation_part = fetch("train[80%:]")
  test_part, meta = fetch("test", with_info=True)
  return train_part, validation_part, test_part, meta.features["label"].num_classes

def load_oxford_flowers102(dataset):
  """Load oxford_flowers102 with its train and test splits exchanged.

  The *dataset* argument is accepted but not used; the dataset name is
  fixed to ``"oxford_flowers102"``.

  Returns:
    (train, validation, test, label_count), where label_count is the
    number of classes reported by the dataset's ``label`` feature.
  """
  ### CAUTION: SWAPPED TRAIN AND TEST FOR SIZE REASONS ###
  name = "oxford_flowers102"
  # TFDS "test" split is used for training ...
  train_part = tfds.load(name, as_supervised=True, split="test")
  validation_part = tfds.load(name, as_supervised=True, split="validation")
  # ... and TFDS "train" split is used for testing.
  test_part, meta = tfds.load(name, as_supervised=True, split="train", with_info=True)
  class_total = meta.features["label"].num_classes
  return train_part, validation_part, test_part, class_total

def preprocess_image(image, label, resolution=(224,224)):
    """Resize an image and scale its pixel values to the range [-1, 1].

    Args:
      image: image tensor; presumably an integer-typed (uint8) image as
        delivered by tensorflow_datasets — TODO confirm.
      label: passed through unchanged.
      resolution: target (height, width) for the resize.

    Returns:
      (image, label) with the image as float32.
    """
    # Convert BEFORE resizing. Converting an integer image to float32
    # rescales it to [0, 1]; tf.image.resize returns float32 in the
    # input's value range, so resizing first left pixels in the raw
    # 0..255 range and turned the conversion into a float->float no-op.
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, resolution)
    image = (image - 0.5) / 0.5  # Map [0, 1] to [-1, 1]
    return image, label

def load(dataset):
  """Load *dataset* and return its preprocessed, batched splits.

  Each split is mapped through ``preprocess_image`` and batched with a
  batch size of 32.

  Args:
    dataset: one of "cifar100", "stanford_dogs" or "oxford_flowers102".

  Returns:
    (train_dataset, validation_dataset, test_dataset, label_count)

  Raises:
    ValueError: if *dataset* is not one of the supported names.

  NOTE(review): this definition rebinds the name ``load`` and thereby
  replaces the pickle loader ``load(filename, with_prefix)`` defined
  earlier in this file. Consider renaming one of the two.
  """
  if dataset in ["cifar100", "stanford_dogs"]:
    loader = load_standard
  elif dataset == "oxford_flowers102":
    loader = load_oxford_flowers102
  else:
    # Previously an unknown name fell through and failed later with an
    # UnboundLocalError on test_dataset; fail early with a clear message.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of "
        "'cifar100', 'stanford_dogs', 'oxford_flowers102'")
  train_dataset, validation_dataset, test_dataset, label_count = loader(dataset)
  test_dataset = test_dataset.map(preprocess_image).batch(32)
  train_dataset = train_dataset.map(preprocess_image).batch(32)
  validation_dataset = validation_dataset.map(preprocess_image).batch(32)
  return train_dataset, validation_dataset, test_dataset, label_count

# Load the configured dataset, then report the length of each batched split
# (printed in the order train, test, validation) followed by the label count.
(train_dataset, validation_dataset, test_dataset, label_count) = load(DATASET)
print("DATASET =", *map(len, (train_dataset, test_dataset, validation_dataset)), label_count)

### Build and compile models ###

def modelmaker(modelnames, num_classes=None):
  """Build and compile one untrained Keras application model per name.

  Args:
    modelnames: iterable of constructor names found in
      ``tf.keras.applications`` (e.g. "ResNet152V2").
    num_classes: number of output classes; when None, the module-level
      ``label_count`` is used.

  Returns:
    list of compiled models, in the order of *modelnames*.

  Raises:
    AttributeError: if a name is not an attribute of
      ``tf.keras.applications``.
  """
  compiled = []
  print("Compiling model", end = ' ')
  for i, modelname in enumerate(modelnames):
    print(i, end = ' ')
    classes = label_count if num_classes is None else num_classes
    # Look the constructor up by attribute instead of exec-ing a built
    # source string: same result, no arbitrary code execution.
    constructor = getattr(tf.keras.applications, modelname)
    model = constructor(weights=None, classes=classes)
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  metrics=['accuracy'])
    compiled.append(model)
  return compiled

# Translate the MODELS preset into the constructor names to instantiate.
if MODELS == "6xresnet":
  modelnames = ["ResNet152V2"] * 6
elif MODELS == "6xefficientnet":
  modelnames = ["EfficientNetB0"] * 6
else:
  # Previously an unknown preset left modelnames undefined and the next
  # line failed with a NameError; report the bad setting explicitly.
  raise ValueError(
      f"Unknown MODELS value {MODELS!r}; expected '6xresnet' or '6xefficientnet'")
# NOTE(review): this rebinds `models`, shadowing the `models` module
# imported from tensorflow.keras at the top of the file.
models = modelmaker(modelnames)

print(len(models), 'models compiled')



PREFIX = 20231204_oxfordflowers102_6xresnet_10ep_
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
FILELOC = /content/drive/MyDrive/edu/results/
DATASET = 193 32 32 102
Compiling model 0 1 2 3 4 5 
 6  models compiled
6 models compiled
